当前位置:   article > 正文

CVAE(条件变分自编码器)、Conditional GAN(条件GAN)和 VAE-GAN 模型之间的区别之 CVAE_vae与cvae区别

vae与cvae区别

# 使用CVAE(条件变分自编码器)训练fashion-mnist数据集

  1. import os
  2. import time
  3. import tensorflow as tf
  4. import numpy as np
  5. from ops import *
  6. from utils import *
  7. class CVAE(object):
  8. model_name = "CVAE" # name for checkpoint
  9. def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
  10. self.sess = sess
  11. self.dataset_name = dataset_name
  12. self.checkpoint_dir = checkpoint_dir
  13. self.result_dir = result_dir
  14. self.log_dir = log_dir
  15. self.epoch = epoch
  16. self.batch_size = batch_size
  17. self.mean = 0
  18. self.var =1
  19. if dataset_name == 'mnist' or dataset_name == 'fashion-mnist':
  20. # parameters
  21. self.input_height = 28
  22. self.input_width = 28
  23. self.output_height = 28
  24. self.output_width = 28
  25. self.z_dim = z_dim # dimension of noise-vector
  26. self.y_dim = 10 # dimension of condition-vector (label)
  27. self.c_dim = 1
  28. # train
  29. self.learning_rate = 0.0002
  30. self.beta1 = 0.5
  31. # test
  32. self.sample_num = 64 # number of generated images to be saved
  33. # load mnist
  34. self.data_X, self.data_y = load_mnist(self.dataset_name)
  35. # get number of batches for a single epoch
  36. self.num_batches = len(self.data_X) // self.batch_size
  37. else:
  38. print("********there is no other dataset to do *********")
  39. raise NotImplementedError
  40. # 编码器中输入的是真实图像和噪音向量
  41. def encoder(self, x, y, is_training=True, reuse=False):
  42. with tf.variable_scope("encoder", reuse=reuse):
  43. y = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
  44. x = conv_cond_concat(x, y)
  45. net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='en_conv1'))
  46. net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='en_conv2'), is_training=is_training, scope='en_bn2'))
  47. net = tf.reshape(net, [self.batch_size, -1])
  48. net = lrelu(bn(linear(net, 1024, scope='en_fc3'), is_training=is_training, scope='en_bn3'))
  49. gaussian_params = linear(net, 2 * self.z_dim, scope='en_fc4')
  50. mean = gaussian_params[:, :self.z_dim]
  51. stddev = tf.nn.softplus(gaussian_params[:, self.z_dim:])
  52. return mean, stddev
  53. # 定义解码器的相关操作
  54. def decoder(self, z, y, is_training=True, reuse=False):
  55. with tf.variable_scope("decoder", reuse=reuse):
  56. z = concat([z, y], 1)
  57. net = tf.nn.relu(bn(linear(z, 1024, scope='de_fc1'), is_training=is_training, scope='de_bn1'))
  58. net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='de_fc2'), is_training=is_training, scope='de_bn2'))
  59. net = tf.reshape(net, [self.batch_size, 7, 7, 128])
  60. net = tf.nn.relu(
  61. bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='de_dc3'), is_training=is_training,
  62. scope='de_bn3'))
  63. out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='de_dc4'))
  64. return out
  65. def build_model(self):
  66. image_dims = [self.input_height, self.input_width, self.c_dim]
  67. bs = self.batch_size
  68. self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')
  69. self.y = tf.placeholder(tf.float32, [bs, self.y_dim], name='y')
  70. self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')
  71. # 编码器中返回的数据是均值和方差 经过运算之后返回数据
  72. mu, sigma = self.encoder(self.inputs, self.y, is_training=True, reuse=False)
  73. z = mu + sigma * tf.random_normal(tf.shape(mu), 0, 1, dtype=tf.float32)
  74. # 解码器输出真实图像
  75. self.out = self.decoder(z, self.y, is_training=True, reuse=False)
  76. # 定义loss函数
  77. marginal_likelihood = tf.reduce_sum(self.inputs * tf.log(self.out) + (1 - self.inputs) * tf.log(1 - self.out), [1, 2])
  78. KL_divergence = 0.5 * tf.reduce_sum(tf.square(mu) + tf.square(sigma) - tf.log(1e-8 + tf.square(sigma)) - 1, [1])
  79. self.neg_loglikelihood = -tf.reduce_mean(marginal_likelihood)
  80. self.KL_divergence = tf.reduce_mean(KL_divergence)
  81. # 这个损失函数不是很懂 生成结果好 就这样用吧
  82. self.loss = self.neg_loglikelihood + self.KL_divergence
  83. # 定义优化器
  84. t_vars = tf.trainable_variables()
  85. with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
  86. self.optim = tf.train.AdamOptimizer(self.learning_rate*5, beta1=self.beta1) \
  87. .minimize(self.loss, var_list=t_vars)
  88. # is_training设置为false生成图像 但是不参与模型修改参数
  89. self.fake_images = self.decoder(self.z, self.y, is_training=False, reuse=True)
  90. self.merged_summary_op = tf.summary.merge_all()
  91. def train(self):
  92. tf.global_variables_initializer().run()
  93. # 标签使用的是前batch_size个图像
  94. self.sample_z = np.random.normal(self.mean, self.var, (self.batch_size, self.z_dim)).astype(np.float32)
  95. self.test_labels = self.data_y[0:self.batch_size]
  96. start_epoch = 0
  97. start_batch_id = 0
  98. counter = 1
  99. start_time = time.time()
  100. for epoch in range(start_epoch, self.epoch):
  101. for idx in range(start_batch_id, self.num_batches):
  102. batch_images = self.data_X[idx*self.batch_size:(idx+1)*self.batch_size]
  103. batch_labels = self.data_y[idx * self.batch_size:(idx + 1) * self.batch_size]
  104. batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)
  105. _, summary_str, loss, nll_loss, kl_loss = self.sess.run([self.optim, self.merged_summary_op, self.loss, self.neg_loglikelihood, self.KL_divergence],
  106. feed_dict={self.inputs: batch_images, self.y: batch_labels, self.z: batch_z})
  107. counter += 1
  108. print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f, nll: %.8f, kl: %.8f" \
  109. % (epoch, idx, self.num_batches, time.time() - start_time, loss, nll_loss, kl_loss))
  110. # save training results for every 300 steps
  111. if np.mod(counter, 300) == 0:
  112. samples = self.sess.run(self.fake_images,
  113. feed_dict={self.z: self.sample_z, self.y: self.test_labels})
  114. tot_num_samples = min(self.sample_num, self.batch_size)
  115. manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
  116. manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
  117. save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
  118. './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
  119. epoch, idx))
  120. start_batch_id = 0
  121. @property
  122. def model_dir(self):
  123. return "{}_{}_{}_{}".format(
  124. self.model_name, self.dataset_name,
  125. self.batch_size, self.z_dim)

训练的结果:

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/41015
推荐阅读
  

闽ICP备14008679号