当前位置:   article > 正文

利用keras框架搭建一个简单CGAN模型,制作一个数据集用作训练_使用cgan训练自己的数据集

使用cgan训练自己的数据集

目录

利用keras框架搭建一个CGAN模型是比较方便的,这里我就不多说什么了,直接上代码吧

        数据集的构建大致如下:

关于tag.txt:

关于tags.txt:

文本相似度的对比可以通过余弦相似度比较,例如以下代码:


        

        目前网络上搭建这种模型的文章已经很多了,在这里我会说一些在阅读其他作者的相关文章的同时进行搭建遇到的一系列问题以及数据集兼容的一点见解。

利用keras框架搭建一个CGAN模型是比较方便的,这里我就不多说什么了,直接上代码吧

  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. from keras.optimizers import Adam
  4. from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, LeakyReLU, Layer, BatchNormalization, \
  5. Embedding
  6. from keras.models import Model, Sequential
  7. import tensorflow as tf
  8. from tensorflow.python.keras.models import save_model
  9. from keras.utils import plot_model
  10. import myDataToMnist
  11. class Mycgan():
  12. def __init__(self):
  13. self.img_rows = 64 # 输入图像像素行
  14. self.img_cols = 64 # 输入图像像素列
  15. self.img_tag = 1 # 标签
  16. self.channels = 3 # 管道数量,彩色图像rgb三个管道
  17. self.img_shapes = (self.img_rows, self.img_cols, self.img_tag, self.channels) # 将三个图像形象组合成一个shape
  18. self.num_classes = 21 # 样本的数量
  19. self.latent_dim = 100 # 特征向量的维度
  20. optimizer = Adam(0.0002, 0.5) # 优化算法,自适应的梯度下降算法,参数为:学习速率,
  21. # 构建判别器
  22. self.discriminator = self.build_Discriminator()
  23. self.discriminator.compile(loss=['binary_crossentropy'], optimizer=optimizer, metrics=['accuracy'])
  24. # 构建生成器
  25. self.generator = self.build_Generator()
  26. # 生成器的输入定义(噪声+标签)
  27. noise = Input(shape=(self.latent_dim,), name='noise')
  28. label = Input(shape=(1,), name='label')
  29. img = self.generator([noise, label]) # 生成器生成图像的输入定义
  30. # 暂时仅训练生成器
  31. self.discriminator.trainable = True
  32. # 鉴别器将生成图像与标签作为输入判别图像真实性
  33. valid = self.discriminator([img, label])
  34. # 组合模型
  35. self.combined = Model([noise, label], valid, name='CGAN_Model')
  36. self.combined.compile(loss=['binary_crossentropy'], optimizer=optimizer)
  37. def build_Generator(self):
  38. # 堆叠模型,允许添加多个层构建深度神经网络
  39. model = Sequential()
  40. model.add(Dense(256, input_dim=self.latent_dim)) # 全连接层
  41. model.add(LeakyReLU(alpha=0.2)) # 带泄露的RLU层,赋予模型一个起始较小的梯度
  42. model.add(BatchNormalization(momentum=0.8)) # 批量标准化层,将输入张量的数据转换到0~1之间
  43. model.add(Dense(512))
  44. model.add(LeakyReLU(alpha=0.2))
  45. model.add(BatchNormalization(momentum=0.8))
  46. model.add(Dense(1024))
  47. model.add(LeakyReLU(alpha=0.2))
  48. model.add(BatchNormalization(momentum=0.8))
  49. model.add(Dense(np.prod(self.img_shapes), activation='tanh')) # 双曲正切激活函数
  50. model.add(Reshape(self.img_shapes)) # 将转换并计算后张量转换到指定尺寸(此处为原图像尺寸)
  51. model.summary() # 打印网络结构及各项参数
  52. noise = Input(shape=(self.latent_dim,))
  53. label = Input(shape=(1,), dtype='float32')
  54. label_embedding = Flatten()(Embedding(input_dim=self.num_classes, output_dim=self.latent_dim)(label))
  55. # label = Input(shape=(1,),dtype='string')
  56. # label_embedding = Flatten()(Embedding(self.num_classes,self.latent_dim)) #将参数联合转换为稠密张量,再展平为一维张量
  57. model_input = multiply([noise, label_embedding]) # 通过将噪声与标签逐一乘积合成一个张量作为整个生成器的输入
  58. img = model(model_input) # 生成图像
  59. save_model(Model([noise, label], img), 'generator.h5')
  60. return Model([noise, label], img, name='generator') # 将生成图像与合成的条件向量作为输出传入到判别器
  61. def build_Discriminator(self):
  62. model = Sequential()
  63. model.add(Dense(512, input_dim=np.prod(self.img_shapes)))
  64. model.add(LeakyReLU(alpha=0.2))
  65. model.add(Dense(512))
  66. model.add(LeakyReLU(alpha=0.2))
  67. model.add(Dropout(0.4)) # 防止过拟合
  68. model.add(Dense(512))
  69. model.add(LeakyReLU(alpha=0.2))
  70. model.add(Dropout(0.4))
  71. model.add(Dense(1, activation='sigmoid'))
  72. model.summary()
  73. img = Input(shape=self.img_shapes)
  74. label = Input(shape=(1,), dtype='int32')
  75. label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shapes))(label))
  76. flat_img = Flatten()(img)
  77. model_input = multiply([flat_img, label_embedding])
  78. validity = model(model_input)
  79. save_model(Model([img, label], validity), 'discriminator.h5')
  80. return Model([img, label], validity, name='discriminator')
  81. def train(self, epochs, batch_size=32, sample_interval=50):
  82. # 加载数据集
  83. (X_train, Y_train) = (myDataToMnist.x_train, myDataToMnist.y_train)
  84. X_train = (X_train.astype(np.float32) - 127.5) / 127.5
  85. X_train = np.expand_dims(X_train, axis=3)
  86. Y_train = Y_train.reshape(-1, 1)
  87. valid = np.ones([batch_size, 1])
  88. fake = np.zeros([batch_size, 1])
  89. for epochs in range(epochs):
  90. index = np.random.randint(0, X_train.shape[0], batch_size)
  91. imgs, labels = X_train[index], Y_train[index]
  92. noise = np.random.normal(0, 1, (batch_size, 100))
  93. gen_img = self.generator.predict([noise, labels])
  94. d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
  95. d_loss_fake = self.discriminator.train_on_batch([gen_img, labels], fake)
  96. d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
  97. sample_labels = np.random.randint(0, 20, batch_size).reshape(-1, 1)
  98. g_loss = self.combined.train_on_batch([noise, sample_labels], valid)
  99. print("%d [D 损失:%f, ACC.: %.2f%%] [G 损失: %f]" % (epochs, d_loss[0], 100 * d_loss[1], g_loss))
  100. if epochs % sample_interval == 0:
  101. self.sample_images(epochs)
  102. tf.keras.models.save_model(self.combined, 'my_cgan_model.h5')
  103. plot_model(self.combined, show_shapes=True, to_file='./model.png')
  104. def sample_images(self, epoch):
  105. r, c = 4, 5
  106. noise = np.random.normal(0, 1, (r * c, 100))
  107. sample_labels = np.arange(0, 20).reshape(-1, 1)
  108. gen_image = self.generator.predict([noise, sample_labels])
  109. gen_image = 0.5 * gen_image + 0.5
  110. fig, axs = plt.subplots(r, c)
  111. cnt = 0
  112. for i in range(r):
  113. for j in range(c):
  114. axs[i, j].imshow(gen_image[cnt, :, :, 0])
  115. axs[i, j].set_title("Digit: %d" % sample_labels[cnt])
  116. axs[i, j].axis('off')
  117. cnt += 1
  118. fig.savefig("images/%d.png" % epoch)
  119. plt.close()
  120. if __name__ == '__main__':
  121. cgan = Mycgan()
  122. cgan.train(epochs=1500, batch_size=128, sample_interval=200)

        需要注意的是,此处数据集不再是mnist灰度图像数据集,而是我自己创建的一个类似格式的彩色图像数据集。

        数据集的构建大致如下:

  1. import os
  2. import numpy as np
  3. import pandas as pd
  4. from PIL import Image
  5. # 定义数据集目录路径及相关信息
  6. data_dir = f"包含图像文件夹的文件夹"
  7. label_data_dir = f"标签文件夹名称"
  8. label_file = '标签,下面有解释'
  9. image_dir = f"图像文件夹名"
  10. label_dict = {}
  11. filename = '映射关联文本'
  12. # 取图像文件夹,从0开始将图像名称作为对应值键入
  13. with open(filename, 'r') as f:
  14. lines = f.readlines()
  15. for line in lines:
  16. key, value = line.strip().split(':')
  17. label_dict[key] = value
  18. print(label_dict.__len__())
  19. # 读取标签文件,构建 y_train 数组
  20. with open(os.path.join(label_data_dir, label_file), "r") as f:
  21. labels = f.readlines()
  22. y_train = np.array([int(label.strip()) for label in labels]) # 删除每个标签后面的换行符
  23. # 读取图像文件,构建 x_train 数组
  24. image_paths = [os.path.join(data_dir, image_dir, fname) for fname in os.listdir(os.path.join(data_dir, image_dir)) if
  25. fname.endswith(".jpg")]
  26. x_train = []
  27. for image_path in image_paths:
  28. with Image.open(image_path) as img:
  29. img = img.resize((64, 64))
  30. img_arr = np.array(img) # 将 PIL.Image 对象转换为 ndarray 数组
  31. x_train.append(img_arr)
  32. x_train = np.array(x_train)
  33. y_train_desc = []
  34. for y in y_train:
  35. y_train_desc.append(label_dict[str(y)])
  36. train_df = pd.DataFrame({"data": list(x_train), "label": y_train_desc})
  37. train_df.to_csv("train.csv", index=False)
  38. print(x_train.shape[1])

关于tag.txt:

 序号,长度为你要构建的图像集的图像数

关于tags.txt:

        映射关系,具体为于当前序号相关联的图像的名称

        大致就是这样,如果想通过控制台输入控制图像生成的话,可以将生成器中label封装一层文本相似度比较的程序,通过比较输入文本与数据集中映射文本的相似度选取最高相似度的文本对应的序号作为随机条件噪声输入,只需要把sample_labels = np.arrange(0,20).reshape(-1,1)改为np.array(输入的随机条件噪声值).reshape(-1,1)就行,当然本代码中的图像展示只是训练结果,如果需要测试结果必须通过读取保存好的模型.h5文件进行predict().

文本相似度的对比可以通过余弦相似度比较,例如以下代码:

  1. # 输入文本进行清洗
  2. import jieba
  3. import myDataToMnist
  4. from sklearn.feature_extraction.text import TfidfVectorizer
  5. from sklearn.metrics.pairwise import cosine_similarity
  6. def text_compare():
  7. # 1.对文本进行分词
  8. print("请输入一些描述词,类似“紫色的花”:")
  9. text1 = input()
  10. len1 = len(text1)
  11. text2 = ""
  12. sim = 0.0
  13. label = ""
  14. for tex2 in myDataToMnist.y_train_desc:
  15. len2 = len(tex2)
  16. if len1 > len2:
  17. text2 += ' ' * (len1 - len2)
  18. else:
  19. text1 += ' ' * (len1 - len2)
  20. sen1 = jieba.lcut(text1)
  21. sen2 = jieba.lcut(tex2)
  22. seg_str1 = ",".join(sen1)
  23. seg_str2 = ",".join(sen2)
  24. vectorizer = TfidfVectorizer()
  25. vectorizer.fit([seg_str1,seg_str2])
  26. vector1,vector2 = vectorizer.transform([seg_str1,seg_str2])
  27. similarity = cosine_similarity(vector1,vector2)[0][0]
  28. if similarity > sim:
  29. sim = similarity
  30. label = tex2
  31. print("文本相似度:",sim)
  32. print("相似文本:",label)
  33. print("噪声编号:", myDataToMnist.y_train_desc.index(label))
  34. # 通过值找到键
  35. def get_key(key, value):
  36. return list(key.keys())[list(key.values()).index(value)]
  37. text_compare()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/42280
推荐阅读
相关标签
  

闽ICP备14008679号