当前位置:   article > 正文

keras实现GAN网络的代码详解_代码noise z -> generator g -> generated samples -> d

代码noise z -> generator g -> generated samples -> discriminator d -> probab

目录

 导入模块

产生真实样本x_sample() & 噪声样本z_sample()

定义生成器generator_model

可视化神经网络的函数 plot_model

定义判别器discriminator_model

定义一个含有判别器D的生成器generator_containing_discriminator(g, d)

绘制直方图show_image                 保存图片save_image

 初始化show_init                           保存loss值

 主函数main                                 参考



 导入模块

注:本文讲解的GAN网络源码在这里  ,所以模块导入就不再讲解,推荐看mnist数据集实现keras_gan.py

产生真实样本x_sample() & 噪声样本z_sample()

  1. mu, sigma=(0,1)
  2. #真实样本满足正态分布 均值mu为0 方差sigma为1 样本维度200,分为32个批量
  3. def x_sample(size=200,batch_size=32):
  4. x=[]
  5. for _ in range(batch_size):
  6. x.append(np.random.normal(mu, sigma, size))
  7. return np.array(x)
  8. #噪声样本 噪声维度为200, 服从[-1,1]上的均匀分布uniform
  9. def z_sample(size=200,batch_size=32):
  10. z=[]
  11. for _ in range(batch_size):
  12. z.append(np.random.uniform(-1, 1, size))
  13. return np.array(z)

定义生成器generator_model

 Dense是全连接层,所实现的运算是output = activation(dot(input, kernel)+bias)  【参考keras中文文档 :Dense层

plot_model是可视化神经网络的函数,并可以将可视化结果保存在本地图片

  1. def generator_model():
  2. model = Sequential()
  3. model.add(Dense(input_dim=200, units=256)) # units表示该层的输出维度
  4. model.add(Activation('relu')) #激活函数
  5. model.add(Dense(200)) #Dense是全连接层
  6. model.add(Activation('sigmoid'))
  7. plot_model(model, show_shapes=True, to_file='gan/keras-gan-generator_model.png')
  8. return model

可视化神经网络的函数 plot_model

plot(model, to_file='model1.png',show_shapes=True)   

生成一个模型图,第一个参数为模型,第二个为要生成图片的路径及文件名,还可以指定两个参数:

  • show_shapes:指定是否显示输出数据的形状,默认为False 
  • show_layer_names:指定是否显示层名称,默认为True

show_shapes为True / False 的 具体显示差别请参考 【keras的模型可视化


定义判别器discriminator_model

Reshape层用来将input_shape转换为特定的shape (200,), Dense是全连接层。

  1. def discriminator_model():
  2. model = Sequential()
  3. model.add(Reshape((200,), input_shape=(200,)))
  4. model.add(Dense(units=256)) # 有256个神经元的全连接层,输出维度就为256
  5. model.add(Activation('relu'))
  6. model.add(Dense(1))
  7. model.add(Activation('sigmoid'))
  8. plot_model(model, show_shapes=True, to_file='gan/keras-gan-discriminator_model.png')
  9. return model

定义一个含有判别器D的生成器generator_containing_discriminator(g, d)

    d.trainable = False    #就是不更新判别器D的权值

  1. def generator_containing_discriminator(g, d):
  2. model = Sequential()
  3. model.add(g)
  4. d.trainable = False
  5. model.add(d)
  6. plot_model(model, show_shapes=True, to_file='gan/keras-gan-gan_model.png')
  7. return model

 


绘制直方图show_image

count 为【?】,bins为分箱【直方图】的数目;

  1. def show_image(s):
  2. count, bins, ignored = plt.hist(s, 5, normed=True) #根据数据s绘制有5个bar的直方图
  3. plt.plot(bins, np.ones_like(bins), linewidth=2, color='r') #按照 bins 的shape创建数组,线条宽度为2 ,颜色为red
  4. plt.show()

 保存图片save_image

导入matplotlib.mlab as MLA,

先绘制一个直方图,normed 为True是频率图,直方图颜色facecolor为白色,直方图边框颜色edgecolor为蓝色。

绘制一条正态概率密度函数 ( normpdf)  曲线;y = normpdf(bins,mu,sigma) 中,mu为均值; sigma:标准差 ; y是正态概率密度函数在bins处的值

plt.savefig()保存图片

  1. def save_image(s,filename):
  2. count, bins, ignored = plt.hist(s, bins=20, normed=True,facecolor='w',edgecolor='b')
  3. y = MLA.normpdf(bins, mu, sigma)
  4. l = plt.plot(bins, y, 'g--', linewidth=2) #绿色虚线,线条宽度为2
  5. #plt.plot(bins, np.ones_like(bins), linewidth=2, color='r')
  6. plt.savefig(filename) #将绘制的图l保存为filename
  7. #plt.show()在savefig()之后,保存的图片才不会是空白

 初始化show_init

按照之前定义的x_sample() 函数来产生真实样本x,z_sample() 函数产生噪声样本z,并保存为图片。

x是个numpy数组,1个batch,样本维度为200,服从N(0,1)分布,

噪声样本z, 维度为200, 服从[-1,1]上的均匀分布uniform

  1. def show_init():
  2. x=x_sample(batch_size=1)[0]
  3. save_image(x,"gan/x-0.png")
  4. z=z_sample(batch_size=1)[0]
  5. save_image(z, "gan/z-0.png")

 保存loss值

需要保存判别器D的损失 d_loss_list,和生成器G的损g_loss_list.作图后保存为gan/loss.png。

可用plt.figure(figsize=(10, 5)) # 更改图片大小
plt.subplots_adjust(left=0.09,right=1,wspace=0.25,hspace=0.25,bottom=0.13,top=0.91) #调节子图

  1. def save_loss(d_loss_list,g_loss_list):
  2. plt.subplot(2, 1, 1) # 面板设置成2行1列,并取第一个(顺时针编号)
  3. plt.plot(d_loss_list, 'yo-') # 画图,染色
  4. #plt.title('A tale of 2 subplots')
  5. plt.ylabel('d_loss')
  6. plt.subplot(2, 1, 2) # 面板设置成2行1列,并取第二个(顺时针编号)
  7. plt.plot(g_loss_list,'r.-') # 画图,染色
  8. #plt.xlabel('time (s)')
  9. plt.ylabel('g_loss')
  10. plt.savefig("gan/loss.png")

 


 主函数main

先用show_init产生样本,进行初始化, lr为学习率,momentum为动量项,使用nesterov动量;g和d的优化方法一样

  1. if __name__ == '__main__':
  2. show_init()
  3. d_loss_list=[]
  4. g_loss_list = []
  5. batch_size=128
  6. d = discriminator_model()
  7. g = generator_model()
  8. d_on_g = generator_containing_discriminator(g, d)
  9. d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
  10. g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
  11. g.compile(loss='binary_crossentropy', optimizer="SGD") #二分类,损失函数用交叉熵
  12. d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)
  13. d.trainable = True #开始模型训练
  14. d.compile(loss='binary_crossentropy', optimizer=d_optim)
  15. for epoch in range(500): #训练500轮
  16. print("Epoch is", epoch)
  17. noise=z_sample(batch_size=batch_size) #噪声样本
  18. image_batch=x_sample(batch_size=batch_size)
  19. generated_images = g.predict(noise, verbose=0) #g网络的预测结果,verbose=0则控制台没有任何输出
  20. x= np.concatenate((image_batch, generated_images)) #数组拼接
  21. y=[1]*batch_size+[0]*batch_size
  22. d_loss = d.train_on_batch(x, y)
  23. print("d_loss : %f" % (d_loss))
  24. noise = z_sample(batch_size=batch_size)
  25. d.trainable = False
  26. g_loss = d_on_g.train_on_batch(noise, [1]*batch_size)
  27. d.trainable = True
  28. print("g_loss : %f" % (g_loss))
  29. d_loss_list.append(d_loss)
  30. g_loss_list.append(g_loss)
  31. if epoch % 100 == 1:
  32. # 测试阶段
  33. noise = z_sample(batch_size=1)
  34. generated_images = g.predict(noise, verbose=0)
  35. # print generated_images
  36. save_image(generated_images[0], "gan/z-{}.png".format(epoch))
  37. save_loss(d_loss_list, g_loss_list)

 


参考

【1】keras中文文档 :常用层

【2】keras 实现GAN(生成对抗网络)

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/一键难忘520/article/detail/744672
推荐阅读
相关标签
  

闽ICP备14008679号