当前位置:   article > 正文

Python! AI绘画_pythonai绘画

pythonai绘画
  1. import tensorflow as tf
  2. from tensorflow import keras
  3. from tensorflow.keras import layers
  # Generator model definition.
def build_generator(latent_dim):
    """Build the fully-connected generator that turns noise vectors into images.

    Three hidden stages of Dense -> LeakyReLU(0.2) -> BatchNormalization with
    256, 512 and 1024 units are followed by a tanh Dense layer of 28 * 28
    units, reshaped to (28, 28, 1). The tanh keeps outputs in [-1, 1].

    Args:
        latent_dim: Length of each input noise vector; only the first Dense
            layer needs it, to fix the model's input shape.

    Returns:
        An uncompiled ``keras.Sequential`` model.
    """
    stack = []
    for position, width in enumerate((256, 512, 1024)):
        if position == 0:
            # Only the first layer declares the input size.
            stack.append(layers.Dense(width, input_dim=latent_dim))
        else:
            stack.append(layers.Dense(width))
        stack.append(layers.LeakyReLU(alpha=0.2))
        stack.append(layers.BatchNormalization())
    stack.append(layers.Dense(28 * 28, activation="tanh"))
    stack.append(layers.Reshape((28, 28, 1)))
    return keras.Sequential(stack)
  # Discriminator model definition.
def build_discriminator(img_shape):
    """Build the fully-connected discriminator that scores images as real/fake.

    The image is flattened and passed through Dense(512) and Dense(256), each
    followed by LeakyReLU(0.2). A final single-unit sigmoid layer outputs a
    probability.

    Args:
        img_shape: Shape of one input image, without the batch axis.

    Returns:
        An uncompiled ``keras.Sequential`` model with output shape (batch, 1).
    """
    head = [layers.Flatten(input_shape=img_shape)]
    for width in (512, 256):
        head.append(layers.Dense(width))
        head.append(layers.LeakyReLU(alpha=0.2))
    head.append(layers.Dense(1, activation="sigmoid"))
    return keras.Sequential(head)
  # Combined GAN model definition.
def build_gan(generator, discriminator):
    """Stack ``generator`` and ``discriminator`` into one Sequential model.

    Side effect: sets ``discriminator.trainable = False`` on the passed-in
    model, so only the generator would be trained through the combined model.

    NOTE(review): under Keras semantics, freezing removes the discriminator's
    weights from ``discriminator.trainable_weights``. The custom training loop
    in this script reads that list and does not call ``build_gan``; calling it
    first would leave the discriminator step with nothing to update.

    Returns:
        An uncompiled ``keras.Sequential`` mapping noise -> D(G(noise)).
    """
    discriminator.trainable = False
    pipeline = [generator, discriminator]
    return keras.Sequential(pipeline)
  # Load the MNIST training images (labels and the test split are discarded).
(x_train, _), (_, _) = keras.datasets.mnist.load_data()
# Add a trailing channel axis and map raw pixel values (0-255) to [-1, 1],
# matching the range of the generator's tanh output.
x_train = (x_train.reshape(-1, 28, 28, 1).astype("float32") / 127.5) - 1.0
# Shuffle through a 1024-element buffer, then group into batches of 32.
dataset = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(buffer_size=1024)
    .batch(32)
)
  # Instantiate the generator and discriminator.
latent_dim = 128  # length of each noise vector fed to the generator
generator = build_generator(latent_dim)
# x_train.shape[1:] is the per-image shape (same tuple as x_train[0].shape).
discriminator = build_discriminator(x_train.shape[1:])
  # Loss function and optimizers.
# Binary cross-entropy on probabilities: the discriminator ends in a sigmoid.
loss_fn = keras.losses.BinaryCrossentropy()
# Separate Adam instances so generator and discriminator keep their own state.
generator_optimizer = keras.optimizers.Adam(beta_1=0.5, learning_rate=2e-4)
discriminator_optimizer = keras.optimizers.Adam(beta_1=0.5, learning_rate=2e-4)
  # Train the GAN with a custom loop: one discriminator update, then one
# generator update, per batch.
epochs = 50


@tf.function
def train_step(real_images):
    """Run one discriminator update and one generator update on a batch.

    Args:
        real_images: One batch of real training images from ``dataset``.

    Returns:
        ``(d_loss, g_loss)`` as scalar tensors. ``d_loss`` is the
        discriminator's real + fake BCE; ``g_loss`` is the generator's BCE
        against "real" labels.
    """
    # Use the actual batch size instead of a hard-coded 32, so a smaller
    # final batch cannot mismatch the noise/label shapes.
    batch_size = tf.shape(real_images)[0]
    real_labels = tf.ones((batch_size, 1))
    fake_labels = tf.zeros((batch_size, 1))

    # Discriminator step: push D(real) toward 1 and D(G(z)) toward 0.
    # training=True makes the generator's BatchNormalization layers use batch
    # statistics and update their moving averages. Without it they run in
    # inference mode for the whole of training.
    noise = tf.random.normal(shape=(batch_size, latent_dim))
    fake_images = generator(noise, training=True)
    with tf.GradientTape() as tape:
        real_loss = loss_fn(real_labels, discriminator(real_images, training=True))
        fake_loss = loss_fn(fake_labels, discriminator(fake_images, training=True))
        d_loss = real_loss + fake_loss
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    discriminator_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

    # Generator step: score fresh fakes against "real" labels, so the
    # generator is pushed to make D(G(z)) approach 1.
    noise = tf.random.normal(shape=(batch_size, latent_dim))
    with tf.GradientTape() as tape:
        fake_images = generator(noise, training=True)
        g_loss = loss_fn(real_labels, discriminator(fake_images, training=True))
    grads = tape.gradient(g_loss, generator.trainable_weights)
    generator_optimizer.apply_gradients(zip(grads, generator.trainable_weights))
    return d_loss, g_loss


for epoch in range(epochs):
    d_loss = g_loss = None
    for real_images in dataset:
        d_loss, g_loss = train_step(real_images)
    # Report the last batch's losses so progress is visible.
    if d_loss is not None:
        print(
            f"epoch {epoch + 1}/{epochs}: "
            f"d_loss={float(d_loss):.4f} g_loss={float(g_loss):.4f}"
        )

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/380409
推荐阅读
相关标签
  

闽ICP备14008679号