赞
踩
- import tensorflow as tf
- from tensorflow import keras
- from tensorflow.keras import layers
-
# Generator model
def build_generator(latent_dim, img_shape=(28, 28, 1)):
    """Build the fully-connected generator that maps noise vectors to images.

    Args:
        latent_dim: Length of each input noise vector.
        img_shape: Shape of one generated image. Defaults to (28, 28, 1), the
            shape previously hard-coded here; pass the same shape that is given
            to ``build_discriminator`` so the two networks stay compatible.

    Returns:
        An uncompiled ``keras.Sequential`` model whose output has shape
        ``(batch, *img_shape)`` with values in [-1, 1] (tanh activation).
    """
    img_shape = tuple(img_shape)
    # Number of output units = product of the image dimensions
    # (28 * 28 * 1 = 784 for the default shape).
    num_pixels = 1
    for dim in img_shape:
        num_pixels *= int(dim)

    model = keras.Sequential(
        [
            layers.Dense(256, input_dim=latent_dim),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(),
            layers.Dense(512),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(),
            layers.Dense(1024),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(),
            # tanh keeps outputs in [-1, 1], the same range the training
            # images are scaled to.
            layers.Dense(num_pixels, activation="tanh"),
            layers.Reshape(img_shape),
        ]
    )
    return model
-
# Discriminator model
def build_discriminator(img_shape):
    """Build the fully-connected discriminator that scores images.

    Args:
        img_shape: Shape of a single input image, e.g. (28, 28, 1).

    Returns:
        A ``keras.Sequential`` model mapping a batch of images to a
        ``(batch, 1)`` tensor of sigmoid probabilities.
    """
    stack = [layers.Flatten(input_shape=img_shape)]
    # Two hidden Dense layers, each followed by a LeakyReLU(0.2).
    for width in (512, 256):
        stack.append(layers.Dense(width))
        stack.append(layers.LeakyReLU(alpha=0.2))
    # Single sigmoid unit: probability that the input is a real image.
    stack.append(layers.Dense(1, activation="sigmoid"))
    return keras.Sequential(stack)
-
# Combined GAN model
def build_gan(generator, discriminator):
    """Chain the generator and discriminator into one noise-to-score model.

    Args:
        generator: Model mapping noise vectors to images.
        discriminator: Model mapping images to a real/fake score.

    Returns:
        A ``keras.Sequential`` model equivalent to
        ``discriminator(generator(noise))``.

    NOTE: this sets ``discriminator.trainable = False`` on the object passed
    in, not on a copy. Once that flag is False, Keras reports no trainable
    weights for the discriminator, so a custom loop that updates
    ``discriminator.trainable_weights`` would silently stop training it unless
    ``trainable`` is set back to True.
    """
    stages = [generator, discriminator]
    # Freeze the discriminator so only the generator learns through this model.
    discriminator.trainable = False
    return keras.Sequential(stages)
-
# Load the MNIST training images; labels and the test split are discarded.
(x_train, _), (_, _) = keras.datasets.mnist.load_data()
# Map pixel values from [0, 255] to [-1, 1] (matching the generator's tanh
# output) and add a trailing channel axis.
x_train = (x_train.astype("float32") / 127.5 - 1.0).reshape(-1, 28, 28, 1)
dataset = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(buffer_size=1024)
    .batch(32)
)
-
# Instantiate the generator and discriminator.
latent_dim = 128  # length of the generator's input noise vector
generator = build_generator(latent_dim)
# x_train.shape[1:] is the shape of one image (same tuple as x_train[0].shape).
discriminator = build_discriminator(x_train.shape[1:])
-
# Loss and optimizers: binary cross-entropy on the discriminator's sigmoid
# probabilities, plus a separate Adam(learning_rate=2e-4, beta_1=0.5)
# instance for each network (each optimizer keeps its own slot variables).
loss_fn = keras.losses.BinaryCrossentropy()
generator_optimizer, discriminator_optimizer = (
    keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5) for _ in range(2)
)
-
# Train the GAN: for every batch, one discriminator update followed by one
# generator update.
epochs = 50
for epoch in range(epochs):
    for real_images in dataset:
        # Derive the batch size from the data instead of hard-coding 32, so the
        # noise/label shapes always match the batch (e.g. a smaller final
        # batch if the dataset size or batch size changes).
        batch_size = tf.shape(real_images)[0]
        real_labels = tf.ones((batch_size, 1))
        fake_labels = tf.zeros((batch_size, 1))

        # Discriminator update: real images -> label 1, generated -> label 0.
        noise = tf.random.normal(shape=(batch_size, latent_dim))
        # training=True: without it BatchNormalization runs in inference mode,
        # normalizing with moving statistics that this loop would never update.
        fake_images = generator(noise, training=True)
        with tf.GradientTape() as tape:
            real_loss = loss_fn(real_labels, discriminator(real_images, training=True))
            fake_loss = loss_fn(fake_labels, discriminator(fake_images, training=True))
            total_loss = real_loss + fake_loss
        grads = tape.gradient(total_loss, discriminator.trainable_weights)
        discriminator_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

        # Generator update: the generator is rewarded when the discriminator
        # labels its images as real (1). Gradients flow through the
        # discriminator, but only the generator's weights are updated.
        noise = tf.random.normal(shape=(batch_size, latent_dim))
        with tf.GradientTape() as tape:
            fake_images = generator(noise, training=True)
            fake_loss = loss_fn(real_labels, discriminator(fake_images, training=True))
        grads = tape.gradient(fake_loss, generator.trainable_weights)
        generator_optimizer.apply_gradients(zip(grads, generator.trainable_weights))
-

Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。