当前位置:   article > 正文

Conditional GAN 代码学习:vaegan 中的 normal_init 函数

vaegan中的normal_init函数

 本文在 GitHub 上的代码基础上做了修改:只保留了比较核心的代码,并简化了相关操作。

 实验基于 MNIST 数据集。

  1. # 库的引用
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. import numpy
  6. from torchvision import datasets
  7. from torch.utils.data import DataLoader
  8. from torchvision import transforms
  9. from torch.autograd import Variable
  10. from torchvision.utils import save_image
  11. # 用到的参数及设置
  12. batch_size = 128
  13. lr = 0.0002
  14. train_epoch = 50
  15. transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
  16. # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  17. # 准备数据
  18. data = datasets.MNIST(root='./data/mnist', train=True, transform=transform, download=True)
  19. data_loader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=4)
  20. def normal_init(m, mean, std):
  21. if isinstance(m, nn.Linear):
  22. m.weight.data.normal_(mean, std)
  23. m.bias.data.zero_()
  24. class Generator(nn.Module):
  25. def __init__(self):
  26. super(Generator, self).__init__()
  27. self.fc1_1 = nn.Linear(100, 256) # 100是噪声的维度
  28. self.fc1_1_bn = nn.BatchNorm1d(256)
  29. self.fc1_2 = nn.Linear(10, 256)
  30. self.fc1_2_bn = nn.BatchNorm1d(256)
  31. self.fc2 = nn.Linear(512, 512)
  32. self.fc2_bn = nn.BatchNorm1d(512)
  33. self.fc3 = nn.Linear(512, 1024)
  34. self.fc3_bn = nn.BatchNorm1d(1024)
  35. self.fc4 = nn.Linear(1024, 784)
  36. # 循环时字典遍历的是key
  37. def weight_init(self, mean, std):
  38. for m in self._modules:
  39. normal_init(self._modules[m], mean, std)
  40. def forward(self, input, label):
  41. batch_size = input.size(0)
  42. x = F.relu(self.fc1_1_bn(self.fc1_1(input)))
  43. y = F.relu(self.fc1_2_bn(self.fc1_2(label)))
  44. x = torch.cat([x, y], 1) # 这样就变为512维了
  45. x = F.relu(self.fc2_bn(self.fc2(x)))
  46. x = F.relu(self.fc3_bn(self.fc3(x)))
  47. x = torch.tanh(self.fc4(x))
  48. # 可以reshape为图片的维度
  49. gen_img = x.view(batch_size, 1, 28, 28)
  50. return gen_img
  51. class Discriminator(nn.Module):
  52. def __init__(self):
  53. super(Discriminator, self).__init__()
  54. self.fc1_1 = nn.Linear(784, 1024)
  55. self.fc1_2 = nn.Linear(10, 1024)
  56. self.fc2 = nn.Linear(2048, 512)
  57. self.fc2_bn = nn.BatchNorm1d(512)
  58. self.fc3 = nn.Linear(512, 256)
  59. self.fc3_bn = nn.BatchNorm1d(256)
  60. self.fc4 = nn.Linear(256, 1)
  61. def weight_init(self, mean, std):
  62. for m in self._modules:
  63. normal_init(self._modules[m], mean, std)
  64. def forward(self, input, label):
  65. # Flatten操作,展平为784个像素后再操作
  66. batch_size = input.shape[0]
  67. x = input.view(batch_size, -1)
  68. x = F.leaky_relu(self.fc1_1(x), 0.2)
  69. y = F.leaky_relu(self.fc1_2(label), 0.2)
  70. x = torch.cat([x, y], 1) # 按列进行连接,变为2048
  71. x = F.leaky_relu(self.fc2_bn(self.fc2(x)), 0.2)
  72. x = F.leaky_relu(self.fc3_bn(self.fc3(x)), 0.2)
  73. x = torch.sigmoid(self.fc4(x))
  74. return x
  75. # network,结构可以参照图片
  76. G = Generator().cuda()
  77. D = Discriminator().cuda()
  78. G.weight_init(mean=0, std=0.02)
  79. D.weight_init(mean=0, std=0.02)
  80. criterion = nn.BCELoss()
  81. G_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
  82. D_optimizer = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
  83. def train(epoch):
  84. # 随着训练进行, 让lr变小
  85. if (epoch+1) == 30:
  86. G_optimizer.param_groups[0]['lr'] /= 10
  87. D_optimizer.param_groups[0]['lr'] /= 10
  88. print("learning rate change!")
  89. if (epoch+1) == 40:
  90. G_optimizer.param_groups[0]['lr'] /= 10
  91. D_optimizer.param_groups[0]['lr'] /= 10
  92. print("learning rate change!")
  93. for ind, (images, labels) in enumerate(data_loader):
  94. # 随机噪声z,100
  95. z = Variable(torch.Tensor(numpy.random.normal(0, 1, (images.size(0), 100))))
  96. # train discriminator D
  97. mini_batch = images.size(0)
  98. D.zero_grad()
  99. #这是对应的二分类真与假的标签值
  100. y_real_ = torch.ones(mini_batch)
  101. y_fake_ = torch.zeros(mini_batch)
  102. # 这两句是生成label的one-hot向量,因为bceloss用的是one-hot
  103. y_label_ = torch.zeros(mini_batch, 10)
  104. y_label_.scatter_(1, labels.view(mini_batch, 1), 1)
  105. # (128,10)
  106. images, y_label_, y_real_, y_fake_ = Variable(images.cuda()), Variable(y_label_.cuda()), Variable(\
  107. y_real_.cuda()), Variable(y_fake_.cuda())
  108. D_result = D(images, y_label_).squeeze() # 这是告诉判别器标签和图片的配对情况
  109. # 这个地方suqeeze就是128, 不是就是(128,1), 不能BCEloss
  110. D_real_loss = criterion(D_result, y_real_) # Discriminator的一个loss
  111. y_ = (torch.rand(mini_batch, 1) * 10).type(torch.LongTensor)
  112. y_label_ = torch.zeros(mini_batch, 10)
  113. y_label_.scatter_(1, y_.view(mini_batch, 1), 1)
  114. z, y_label_ = Variable(z.cuda()), Variable(y_label_.cuda())
  115. G_result = G(z, y_label_)
  116. gen_imgs = G(z, y_label_)
  117. D_result = D(G_result, y_label_).squeeze()
  118. D_fake_loss = criterion(D_result, y_fake_) # Discriminator的假图片loss
  119. D_train_loss = D_real_loss + D_fake_loss
  120. D_train_loss.backward()
  121. D_optimizer.step()
  122. # train generator G
  123. G.zero_grad()
  124. # 杂乱的生成的标签
  125. z_ = torch.rand((mini_batch, 100))
  126. y_ = (torch.rand(mini_batch, 1) * 10).type(torch.LongTensor)
  127. y_label_ = torch.zeros(mini_batch, 10)
  128. y_label_.scatter_(1, y_.view(mini_batch, 1), 1)
  129. z_, y_label_ = Variable(z_.cuda()), Variable(y_label_.cuda())
  130. G_result = G(z_, y_label_)
  131. D_result = D(G_result, y_label_).squeeze()
  132. G_train_loss = criterion(D_result, y_real_)
  133. G_train_loss.backward()
  134. G_optimizer.step()
  135. if ind % 50 == 0:
  136. print('G_loss:', G_train_loss.data.item(), 'D_loss:', D_train_loss.data.item())
  137. save_image(gen_imgs.data[:25], 'images{}.png'.format(epoch), nrow=5, normalize=True)
  138. if __name__ == '__main__':
  139. for epoch in range(train_epoch):
  140. train(epoch)

 结果如图所示,生成效果还是不错的。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/41012?site
推荐阅读
相关标签
  

闽ICP备14008679号