PyTorch in Practice: Image Generation and Adversarial Networks

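1 Overview

What: given a sentence or a set of requirements, generate an image that satisfies them.

This summary mainly covers deconvolution and the generative adversarial network (GAN).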

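2 Deconvolution and image generation

What: deconvolution (also called transposed convolution) can be viewed as the reverse of convolution, but it is not an exact inverse, and it is not obtained by simply running a convolution backwards. Given a feature map, it generates an output in the shape of the layer's input; it restores the shape, not the exact values. The kernel used by the deconvolution also differs from the one used by the convolution.

Effect: convolution makes a large image progressively smaller, while deconvolution makes an image progressively larger.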

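2.1 The deconvolution operation

Different kernel: rotating the convolution kernel by 180 degrees gives the kernel used by the deconvolution operation.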
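
The kernel relationship above can be checked in one dimension, where a 180-degree rotation is simply a reversal. The following is a minimal, framework-free sketch rather than PyTorch's implementation; the function names and the sample values of `x` and `w` are assumptions made for illustration. It compares a transposed convolution written as "scatter a scaled copy of the kernel" with an ordinary sliding-window product over a zero-padded input using the reversed kernel.

```python
def transposed_conv1d(x, w):
    """Stride-1 transposed convolution: every input value scatters a scaled copy of w."""
    out = [0.0] * (len(x) + len(w) - 1)
    for i, xi in enumerate(x):
        for k, wk in enumerate(w):
            out[i + k] += xi * wk
    return out


def correlate_full(x, w):
    """Sliding-window product (what a conv layer computes) over a zero-padded input."""
    pad = len(w) - 1
    xp = [0.0] * pad + list(x) + [0.0] * pad
    return [sum(xp[i + k] * w[k] for k in range(len(w)))
            for i in range(len(xp) - len(w) + 1)]


x = [1.0, 2.0, 3.0]    # hypothetical input
w = [1.0, 0.0, -1.0]   # hypothetical kernel

a = transposed_conv1d(x, w)
b = correlate_full(x, w[::-1])   # reversed kernel: the 1-D analogue of a 180-degree rotation
print(a)  # [1.0, 2.0, 2.0, -2.0, -3.0]
print(b)  # [1.0, 2.0, 2.0, -2.0, -3.0]
```

Both routes give the same five values, which is why a deconvolution can be described as a convolution with the rotated kernel applied to a padded input.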

Padding: to keep the image size unchanged after the deconvolution, compute the required padding and pad the input image accordingly.
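
The size bookkeeping behind this can be sketched with two small formulas. They follow the output-size conventions documented for PyTorch's `Conv2d` and `ConvTranspose2d` with dilation 1 and no output padding; the helper names are hypothetical. Note that under PyTorch's convention the `padding` argument of a transposed convolution trims the output, which corresponds to padding the input by `kernel - 1 - padding` zeros in the convolution view described above.

```python
def conv_out_size(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out_size(size, kernel, stride=1, padding=0):
    """Output length of a transposed convolution along one axis
    (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


def same_size_padding(kernel):
    """Padding that keeps the size unchanged at stride 1 (odd kernels only)."""
    if kernel % 2 == 0:
        raise ValueError("an odd kernel size is required")
    return (kernel - 1) // 2


shrunk = conv_out_size(28, kernel=5)     # convolution shrinks: 28 -> 24
grown = deconv_out_size(24, kernel=5)    # deconvolution grows: 24 -> 28
kept = deconv_out_size(28, kernel=5, padding=same_size_padding(5))  # 28 -> 28
print(shrunk, grown, kept)  # 24 28 28
```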

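2.2 Unpooling

There are many ways to undo pooling. One convolution technique lets the pooling step be dropped altogether because the effect is similar: give the convolution a stride, that is, the number of positions the kernel moves to reach the next window after finishing the current one. The default stride is 1. A convolution with a stride greater than 1 behaves much like a convolution followed by pooling, so it can be regarded as a stride-1 convolution plus pooling with the separate pooling layer omitted.

Effect of a stride of 2 or more: the output is smaller than the output of a convolution with a smaller stride, so the deconvolution has to account for this case as well.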
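
To make the comparison concrete, here is a small one-dimensional sketch in plain Python; the helper names and the sample signal and kernel are assumptions for illustration. It shows that a stride-2 convolution yields the same number of outputs as a stride-1 convolution followed by size-2 pooling, and that it is exactly the stride-1 result with every second value kept. The values differ from max pooling, which is why the text says the effect is similar rather than identical.

```python
def conv1d(x, w, stride=1):
    """Valid 1-D sliding-window product with a configurable stride."""
    last_start = len(x) - len(w)
    return [sum(x[i + k] * w[k] for k in range(len(w)))
            for i in range(0, last_start + 1, stride)]


def max_pool1d(x, size=2):
    """Non-overlapping max pooling."""
    return [max(x[i:i + size]) for i in range(0, len(x) - size + 1, size)]


x = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0, 5.0, 3.0]   # hypothetical signal
w = [1.0, -1.0, 2.0]                                      # hypothetical kernel

dense = conv1d(x, w, stride=1)       # 8 outputs
strided = conv1d(x, w, stride=2)     # 4 outputs
pooled = max_pool1d(dense, size=2)   # 4 outputs

print(len(dense), len(strided), len(pooled))
print(strided == dense[::2])   # the strided result is a subsampled stride-1 result
```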

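2.3 Deconvolution and fractional strides

A convolution with a stride of 2 or more can be reversed by a deconvolution with a fractional stride: blank (zero) points are inserted between neighbouring pixels of the input image, and the larger the convolution stride, the more blank points the deconvolution inserts between pixels.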
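
The zero-insertion view can be verified in one dimension. The sketch below is framework-free and uses hypothetical helper names and sample values: a transposed convolution with stride 2 produces the same result as first placing one zero between neighbouring inputs and then applying the stride-1 version.

```python
def transposed_conv1d(x, w, stride=1):
    """Transposed convolution: input i scatters a scaled copy of w at offset i * stride."""
    out = [0.0] * ((len(x) - 1) * stride + len(w))
    for i, xi in enumerate(x):
        for k, wk in enumerate(w):
            out[i * stride + k] += xi * wk
    return out


def insert_zeros(x, stride):
    """Place stride - 1 zeros between neighbouring samples."""
    out = []
    for i, xi in enumerate(x):
        if i > 0:
            out.extend([0.0] * (stride - 1))
        out.append(xi)
    return out


x = [1.0, 2.0, 3.0]   # hypothetical input
w = [1.0, 2.0, 1.0]   # hypothetical kernel

direct = transposed_conv1d(x, w, stride=2)
via_zeros = transposed_conv1d(insert_zeros(x, 2), w, stride=1)
print(direct)     # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 3.0]
print(via_zeros)  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 3.0]
```

With stride 3 the helper would insert two zeros between samples, matching the statement that a larger stride means more blank points.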

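2.4 Batch normalization

Concept: a linear layer inserted between each network layer and the nonlinearity that follows it, computing y = a*x + b. Here a and b are learned parameters, and x is the input after normalization over the current batch: (x - mean(x)) / std(x).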
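
As a minimal sketch of that formula for a single feature, assuming the biased batch variance and a small `eps` added for numerical stability (as `nn.BatchNorm2d` does per channel in training mode), the computation looks like this. The function name and the sample batch are hypothetical, and the running statistics used at inference time are left out.

```python
import math


def batch_norm(batch, a=1.0, b=0.0, eps=1e-5):
    """Normalize one feature over a batch, then apply the learned scale a and shift b."""
    mean = sum(batch) / len(batch)
    var = sum((v - mean) ** 2 for v in batch) / len(batch)   # biased variance
    std = math.sqrt(var + eps)
    return [a * (v - mean) / std + b for v in batch]


batch = [2.0, 4.0, 6.0, 8.0]                 # hypothetical activations of one feature
normalized = batch_norm(batch)               # roughly zero mean, unit variance
scaled = batch_norm(batch, a=2.0, b=0.5)     # mean moves to b, spread scales with a

print([round(v, 3) for v in normalized])
print(round(sum(scaled) / len(scaled), 3))
```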

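3 Image generation: the minimum mean squared error model

3.1 Approach

The input is a digit, and the output is a handwritten image of that digit. A deconvolution network implements this mapping from input to output.

3.2 Implementation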

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms
    import torchvision.utils as util
    import matplotlib.pyplot as pyplot
    import numpy as np
    import os

    output_img_size = 28   # MNIST images are 28x28
    input_dim = 100        # channels of the 1x1 input fed to the network
    channel_num = 1        # grayscale output
    features_num = 64      # base number of feature maps
    batch_size = 64

    print('prepare datasets begin')
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

    # Split the test set into a validation half and a test half
    index_verify = range(len(test_dataset))[:5000]
    index_test = range(len(test_dataset))[5000:]
    sampler_verify = torch.utils.data.sampler.SubsetRandomSampler(index_verify)
    sampler_test = torch.utils.data.sampler.SubsetRandomSampler(index_test)
    verify_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, sampler=sampler_verify)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, sampler=sampler_test)


    class AntiCNN(nn.Module):
        """Deconvolution network: (N, input_dim, 1, 1) -> (N, channel_num, 28, 28)."""

        def __init__(self):
            super(AntiCNN, self).__init__()
            self.model = nn.Sequential()
            self.model.add_module('deconv1', nn.ConvTranspose2d(input_dim, features_num * 2, 5, 2, 0, bias=False))
            self.model.add_module('batch_norm1', nn.BatchNorm2d(features_num * 2))
            self.model.add_module('relu1', nn.ReLU(True))
            self.model.add_module('deconv2', nn.ConvTranspose2d(features_num * 2, features_num, 5, 2, 0, bias=False))
            self.model.add_module('batch_norm2', nn.BatchNorm2d(features_num))
            self.model.add_module('relu2', nn.ReLU(True))
            self.model.add_module('deconv3', nn.ConvTranspose2d(features_num, channel_num, 4, 2, 0, bias=False))
            self.model.add_module('sigmoid', nn.Sigmoid())

        def forward(self, input):
            return self.model(input)


    def weight_init(module):
        # Class names are 'ConvTranspose2d' and 'BatchNorm2d', so the match must be
        # capitalized; the lowercase 'conv' / 'norm' checks never matched anything.
        class_name = module.__class__.__name__
        if class_name.find('Conv') != -1:
            module.weight.data.normal_(0, 0.02)  # mean, std
        if class_name.find('Norm') != -1:
            module.weight.data.normal_(1, 0.02)


    def labels_to_input(labels):
        # Turn digit labels of shape (N,) into a float tensor of shape (N, input_dim, 1, 1)
        labels = labels.to(device).float().view(-1, 1, 1, 1)
        return labels.expand(labels.size(0), input_dim, 1, 1)


    def resize_to_img(img):
        # Repeat the single channel three times so the batch is saved as RGB
        return img.data.expand(img.size(0), 3, output_img_size, output_img_size)


    def imgshow(input, title=None):
        # Display helper: multi-channel tensors go to HWC, single-channel to 2-D
        if input.size()[0] > 1:
            input = input.numpy().transpose((1, 2, 0))
        else:
            input = input[0].numpy()
        min_val, max_val = np.amin(input), np.amax(input)
        if max_val > min_val:
            input = (input - min_val) / (max_val - min_val)
        pyplot.imshow(input)
        if title:
            pyplot.title(title)
        pyplot.pause(0.001)


    def main():
        net = AntiCNN().to(device)
        net.apply(weight_init)  # the helper was defined but never applied before
        criterion = nn.MSELoss()
        optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

        # Fixed random digits, rendered after every epoch to inspect progress
        samples = torch.from_numpy(np.random.choice(10, batch_size))

        step = 0
        num_epoch = 2
        record = []
        print('train begin')
        for epoch in range(num_epoch):
            print(f'the no.{epoch} epoch')
            train_loss = []
            for batch_index, (images, labels) in enumerate(train_loader):
                # The digit label is the network input; the image is the regression target
                data, target = labels_to_input(labels), images.to(device)
                net.train()
                output = net(data)
                loss = criterion(output, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                step += 1
                train_loss.append(loss.item())

                if batch_index % 300 == 0:
                    # Periodic validation; no gradients are needed here
                    net.eval()
                    verify_loss = []
                    with torch.no_grad():
                        for images, labels in verify_loader:
                            data, target = labels_to_input(labels), images.to(device)
                            output = net(data)
                            verify_loss.append(criterion(output, target).item())
                    print(f'now no.{batch_index} batch. train loss:{np.mean(train_loss):.4f}, verify loss:{np.mean(verify_loss):.4f}')
                    record.append([np.mean(train_loss), np.mean(verify_loss)])

            # Save generated digits for this epoch; `samples` itself is left unchanged
            # so that it can be reused in the next epoch
            net.eval()
            with torch.no_grad():
                fake_u = net(labels_to_input(samples)).cpu()
                img = resize_to_img(fake_u)
                os.makedirs(os.path.realpath('./pytorch/jizhi/image_generate/temp1'), exist_ok=True)
                util.save_image(img, os.path.realpath(f'./pytorch/jizhi/image_generate/temp1/fake{epoch}.png'))
        pyplot.show()


    if __name__ == '__main__':
        main()
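
As a quick sanity check on the architecture, the spatial size produced by the three `ConvTranspose2d` layers in the listing can be traced by hand. The sketch below assumes the documented output-size formula for transposed convolutions with dilation 1 and no output padding; the helper name is hypothetical, and the (kernel, stride, padding) triples are read off the listing.

```python
def deconv_out_size(size, kernel, stride, padding=0):
    """Output length of a transposed convolution along one axis
    (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) of deconv1, deconv2 and deconv3 in the listing
layers = [(5, 2, 0), (5, 2, 0), (4, 2, 0)]

sizes = [1]   # the network input is a 1x1 map with input_dim channels
for kernel, stride, padding in layers:
    sizes.append(deconv_out_size(sizes[-1], kernel, stride, padding))

print(sizes)  # [1, 5, 13, 28]
```

The final size of 28 matches `output_img_size`, so the output can be compared pixel by pixel with an MNIST image in the MSE loss.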

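The generated images turn out to be very blurry. A likely cause is that minimizing the mean squared error pushes the output toward the average of all handwritten samples of a digit; because the individual images share no sharp common pattern, that average is itself blurry. What can be done? One option is to use the handwritten digit recognizer built earlier to help correct the MSE objective.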
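
The averaging effect can be illustrated on a single pixel. In the sketch below, the pixel intensities are made-up values standing in for the same pixel across several handwritten samples of one digit, and the function name is hypothetical. Among the candidate predictions, the one with the lowest mean squared error is the mean of the targets, which is a grey value that matches none of the samples exactly.

```python
def mse(prediction, targets):
    """Mean squared error of one prediction against several equally valid targets."""
    return sum((prediction - t) ** 2 for t in targets) / len(targets)


# Hypothetical intensities of one pixel across four samples of the same digit:
# the stroke passes through this pixel in two of them and misses it in the others
targets = [0.0, 0.0, 1.0, 1.0]
mean = sum(targets) / len(targets)

candidates = [0.0, 0.25, 0.5, 0.75, 1.0]
losses = {c: mse(c, targets) for c in candidates}
best = min(losses, key=losses.get)

print(losses)
print(best, mean)  # 0.5 0.5
```

Applied to every pixel at once, this is consistent with the blurry outputs described above: a network trained with MSE on a one-to-many mapping is rewarded for producing the per-pixel average.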

4 Image generation: the generator-recognizer model

5 Image generation: GAN

6 Summary
