
[Image Generation] (IV) Diffusion: Principles & PyTorch Code Example

1. Introduction

Having covered the image-generation networks GAN and VAE in previous posts, we finally arrive at Diffusion. Stable Diffusion itself is fairly complex, combining a diffusion model, a VAE, CLIP, and other components; here we focus on the diffusion network itself.


2. Principles

Taken literally, a diffusion model diffuses noise. It consists of two diffusion processes:

  • Forward diffusion: Gaussian noise is added to the image according to a predefined noise schedule, until the data distribution approaches the prior distribution.
  • Reverse diffusion: noise is removed from the image; in essence, the model learns to gradually recover the original data distribution.

The forward diffusion process is easy to understand: at each step, random Gaussian noise is added on top of the image from the previous step, so after many steps the image becomes pure noise.

The reverse diffusion process uses a UNet to predict the Gaussian noise to be removed, thereby denoising the image. Provided each added noise increment is small, the reverse denoising step can likewise be modeled as predicting Gaussian noise.

By the Markov chain property, writing x_{t} for the image at step t and q for its distribution, the forward diffusion process can be expressed as:

q(x_{1},...,x_{T}|x_{0})=\prod_{t=1}^{T}q(x_{t}|x_{t-1})

q(x_{t}|x_{t-1})=\mathcal N(x_{t};\sqrt{1-\beta_{t}}\,x_{t-1},\beta_{t}I)

Here \beta_{t} is the noise coefficient at step t. In the second equation, x_{t} is the output of the Gaussian and x_{t-1} its input; \sqrt{1-\beta_{t}}\,x_{t-1} is the mean and \beta_{t}I the variance. In other words, the image x_{t} at step t is obtained by scaling the previous image x_{t-1} by \sqrt{1-\beta_{t}} and then adding Gaussian noise with variance \beta_{t}.

Why choose \sqrt{1-\beta_{t}} and \beta_{t} for the scaling factor and the variance? All of this is in service of the reparameterization trick described next.

During training we need to corrupt images with noise at many different steps, so that the network sees data at every noise level. In practice, when we need a heavily noised image at a large step, we cannot afford to re-run the noising from step 0 every time; the time cost would be far too high. Also, as we saw in the previous post on VAEs, the random sampling must be expressed through a standard normal variable, otherwise gradients cannot be backpropagated; that is why a Gaussian is decomposed into a specific mean and standard deviation applied to a standard normal. In a diffusion model we stack many Gaussian noises on top of each other: is there a way to collapse this stack into a single Gaussian with a specific mean and standard deviation?

Reparameterization: a Gaussian variable can be decomposed into a specific mean and standard deviation applied to a standard normal variable:

if z \sim \mathcal N (\mu ,\sigma ^{2}) then

z=\mu +\sigma \varepsilon, where \varepsilon \sim \mathcal N(0,1)
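As a quick illustration (a minimal sketch of the trick, not part of the original code): sampling z ~ N(mu, sigma^2) this way keeps the computation differentiable with respect to mu and sigma:

import torch

mu = torch.tensor(2.0, requires_grad=True)     # learnable mean
sigma = torch.tensor(0.5, requires_grad=True)  # learnable std
eps = torch.randn(10000)                       # eps ~ N(0, 1), sampled outside the graph
z = mu + sigma * eps                           # z ~ N(mu, sigma^2); gradients flow to mu and sigma
print(z.mean().item(), z.std().item())         # approximately 2.0 and 0.5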

Applying this trick, and defining \alpha_{t}=1-\beta_{t} with \bar{\alpha}_{t}=\prod_{s=1}^{t}\alpha_{s}, the image at step t can be written as:

\begin{aligned} x_{t}&=\sqrt{1-\beta_{t}}x_{t-1}+\sqrt{\beta_{t}}\epsilon_{t-1} \\ &=\sqrt{\alpha_{t}}x_{t-1}+\sqrt{1-\alpha_{t}}\epsilon_{t-1}\\ &=\sqrt{\alpha_{t}}(\sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t-1}}\epsilon_{t-2})+\sqrt{1-\alpha_{t}}\epsilon_{t-1}\\ &=\sqrt{\alpha_{t}\alpha_{t-1}}x_{t-2}+\sqrt{\alpha_{t}(1-\alpha_{t-1})}\epsilon_{t-2}+\sqrt{1-\alpha_{t}}\epsilon_{t-1} \\ &=\sqrt{\alpha_{t}\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t}\alpha_{t-1}}\bar{\epsilon}_{t-2}\\ &\;\;\vdots\\ &=\sqrt{\bar{\alpha}_{t}}x_{0}+\sqrt{1-\bar{\alpha}_{t}}\epsilon \end{aligned}

Note that when two independent Gaussian variables are added, the result is again Gaussian:

Z = X + Y, \quad X \sim \mathcal N(\mu_{X},\sigma_{X}^{2}),\; Y \sim \mathcal N(\mu_{Y},\sigma_{Y}^{2})

Z \sim \mathcal N (\mu_{X}+\mu_{Y},\sigma_{X}^{2}+\sigma_{Y}^{2})

This is why the fourth line above collapses directly into the fifth: the two noise terms are independent zero-mean Gaussians, so they merge into a single Gaussian \bar{\epsilon}_{t-2}.
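Checking the variances confirms the merge (both means are zero):

\alpha_{t}(1-\alpha_{t-1})+(1-\alpha_{t})=1-\alpha_{t}\alpha_{t-1}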

So we can simply precompute \bar{\alpha}_{t} for every step and jump directly to any noise level, with no per-step iteration.
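A minimal sketch of this one-shot noising (the names and the linear schedule here are illustrative; the full DDPM class below does the same thing with precomputed buffers):

import torch

T = 400
beta = torch.linspace(1e-4, 0.02, T)       # linear noise schedule
alpha = 1.0 - beta
alphabar = torch.cumprod(alpha, dim=0)     # \bar{\alpha}_t, computed once

def q_sample(x0, t, eps):
    # x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps, in a single jump
    return alphabar[t].sqrt() * x0 + (1 - alphabar[t]).sqrt() * eps

x0 = torch.randn(1, 28, 28)                # stand-in for a real image
x_t = q_sample(x0, t=200, eps=torch.randn_like(x0))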

In the reverse diffusion process, because each added noise increment is small, q(x_{t-1}|x_{t}) can also be treated as Gaussian and fitted with a UNet. The full derivation is fairly involved; see the original DDPM paper: https://arxiv.org/abs/2006.11239
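For reference, the end result of that derivation (standard DDPM, stated here without proof): conditioned on x_{0}, the reverse posterior is Gaussian,

q(x_{t-1}|x_{t},x_{0})=\mathcal N(x_{t-1};\tilde{\mu}_{t}(x_{t},x_{0}),\tilde{\beta}_{t}I)

\tilde{\mu}_{t}=\frac{\sqrt{\alpha_{t}}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}x_{t}+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_{t}}{1-\bar{\alpha}_{t}}x_{0},\qquad \tilde{\beta}_{t}=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}}\beta_{t}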

Finally, a KL-divergence bound is minimized to bring the learned reverse distribution as close as possible to the forward one, which reduces to an MSE loss between the true and the predicted noise. The training and sampling procedures follow Algorithms 1 and 2 of the DDPM paper.
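Substituting the noise prediction \epsilon_{\theta}(x_{t},t) for x_{0} in \tilde{\mu}_{t} yields the per-step sampling rule that the sample() code below implements (z \sim \mathcal N(0,I) for t>1, and z=0 at the last step):

x_{t-1}=\frac{1}{\sqrt{\alpha_{t}}}\left(x_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}}\epsilon_{\theta}(x_{t},t)\right)+\sqrt{\beta_{t}}z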


3. Code

Next we use PyTorch to implement a diffusion model that generates MNIST digits.

3.1 Model

The UNet's upsampling and downsampling modules are both built from residual blocks, plus fully connected layers that embed the step index. After the data has been downsampled, each upsampling output is added to the step-embedding vector before being fed into the next upsampling layer.

class ResidualConvBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        '''
        standard ResNet style convolutional block
        '''
        self.same_channels = in_channels == out_channels
        self.is_res = is_res
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.is_res:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            # this adds on correct residual in case channels have increased
            if self.same_channels:
                out = x + x2
            else:
                out = x1 + x2
            return out / 1.414
        else:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2


class UnetDown(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetDown, self).__init__()
        '''
        process and downscale the image feature maps
        '''
        layers = [ResidualConvBlock(in_channels, out_channels), nn.MaxPool2d(2)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UnetUp(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()
        '''
        process and upscale the image feature maps
        '''
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        x = torch.cat((x, skip), 1)
        x = self.model(x)
        return x


class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        '''
        generic one layer FC NN for embedding things
        '''
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        return self.model(x)


class Unet(nn.Module):
    def __init__(self, in_channels, n_feat=256):
        super(Unet, self).__init__()
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)
        self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, 1 * n_feat)
        self.up0 = nn.Sequential(
            # nn.ConvTranspose2d(6 * n_feat, 2 * n_feat, 7, 7), # when concat temb and cemb end up w 6*n_feat
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7), # otherwise just have 2*n_feat
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, t):
        '''
        Given a noised image and its time step, predict the noise that was added.
        :param x: noised image
        :param t: corresponding step, normalized to (0, 1]
        :return: predicted Gaussian noise
        '''
        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)
        hiddenvec = self.to_vec(down2)
        # embed time step
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
        # add the step embedding to each upsampling output before the next up layer
        up1 = self.up0(hiddenvec)
        up2 = self.up1(up1 + temb1, down2)
        up3 = self.up2(up2 + temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out

3.2 Training

During training we pick a random step and draw standard normal noise, combine them into the noised image, then feed the noised image and the step into the UNet; the network's predicted noise for that step is compared against the true noise to compute the loss.

def forward(self, x):
    """
    During training, pick a random step and random noise for each image
    """
    # pick a random step, t ~ Uniform({1, ..., n_T})
    _ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(self.device)
    # draw standard normal noise, eps ~ N(0, 1)
    noise = torch.randn_like(x)
    # noised image x_t, built in one jump from x_0
    x_t = (
        self.sqrtab[_ts, None, None, None] * x
        + self.sqrtmab[_ts, None, None, None] * noise
    )
    # compare the UNet's predicted noise at this step against the true noise
    return self.loss_mse(noise, self.model(x_t, _ts / self.n_T))

3.3 Inference & Visualization

At inference time we start from random initial noise. At each step the network predicts the noise, the sampling rule above gives the mean and variance of the reverse step, and the reparameterization formula produces the image at the previous step. Repeating this over all steps yields the final denoised image.

def sample(self, n_sample, size, device):
    # start from random initial noise, x_T ~ N(0, 1)
    x_i = torch.randn(n_sample, *size).to(device)
    for i in range(self.n_T, 0, -1):
        t_is = torch.tensor([i / self.n_T]).to(device)
        t_is = t_is.repeat(n_sample, 1, 1, 1)
        z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
        eps = self.model(x_i, t_is)
        x_i = x_i[:n_sample]
        x_i = self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
    return x_i
@torch.no_grad()
def visualize_results(self, epoch):
    self.sampler.eval()
    # output directory for generated samples
    output_path = 'results/Diffusion'
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    tot_num_samples = self.sample_num
    image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
    out = self.sampler.sample(tot_num_samples, (1, 28, 28), self.device)
    save_image(out, os.path.join(output_path, '{}.jpg'.format(epoch)), nrow=image_frame_dim)

The complete code is as follows:

import torch, time, os
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torch.nn.functional as F


class ResidualConvBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        '''
        standard ResNet style convolutional block
        '''
        self.same_channels = in_channels == out_channels
        self.is_res = is_res
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.is_res:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            # this adds on correct residual in case channels have increased
            if self.same_channels:
                out = x + x2
            else:
                out = x1 + x2
            return out / 1.414
        else:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2


class UnetDown(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetDown, self).__init__()
        '''
        process and downscale the image feature maps
        '''
        layers = [ResidualConvBlock(in_channels, out_channels), nn.MaxPool2d(2)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UnetUp(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()
        '''
        process and upscale the image feature maps
        '''
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        x = torch.cat((x, skip), 1)
        x = self.model(x)
        return x


class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        '''
        generic one layer FC NN for embedding things
        '''
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        return self.model(x)


class Unet(nn.Module):
    def __init__(self, in_channels, n_feat=256):
        super(Unet, self).__init__()
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)
        self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, 1 * n_feat)
        self.up0 = nn.Sequential(
            # nn.ConvTranspose2d(6 * n_feat, 2 * n_feat, 7, 7), # when concat temb and cemb end up w 6*n_feat
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7), # otherwise just have 2*n_feat
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, t):
        '''
        Given a noised image and its time step, predict the noise that was added.
        :param x: noised image
        :param t: corresponding step, normalized to (0, 1]
        :return: predicted Gaussian noise
        '''
        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)
        hiddenvec = self.to_vec(down2)
        # embed time step
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
        # add the step embedding to each upsampling output before the next up layer
        up1 = self.up0(hiddenvec)
        up2 = self.up1(up1 + temb1, down2)
        up3 = self.up2(up2 + temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out


class DDPM(nn.Module):
    def __init__(self, model, betas, n_T, device):
        super(DDPM, self).__init__()
        self.model = model.to(device)
        # register_buffer stores the precomputed alpha-related schedules, saving time later
        for k, v in self.ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)
        self.n_T = n_T
        self.device = device
        self.loss_mse = nn.MSELoss()

    def ddpm_schedules(self, beta1, beta2, T):
        '''
        Precompute the alpha schedule for every step; beta varies linearly here.
        :param beta1: lower bound of beta
        :param beta2: upper bound of beta
        :param T: total number of steps
        '''
        assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"
        beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1  # evenly spaced values from beta1 to beta2
        sqrt_beta_t = torch.sqrt(beta_t)
        alpha_t = 1 - beta_t
        log_alpha_t = torch.log(alpha_t)
        alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()  # cumulative product of alpha
        sqrtab = torch.sqrt(alphabar_t)  # sqrt of the cumulative product
        oneover_sqrta = 1 / torch.sqrt(alpha_t)  # 1 / sqrt(alpha)
        sqrtmab = torch.sqrt(1 - alphabar_t)  # sqrt(1 - cumulative product)
        mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab
        return {
            "alpha_t": alpha_t,  # \alpha_t
            "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
            "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
            "alphabar_t": alphabar_t,  # \bar{\alpha_t}
            "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}, coefficient of x_0 (mean term)
            "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}, std of the added noise
            "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
        }

    def forward(self, x):
        """
        During training, pick a random step and random noise for each image
        """
        # pick a random step, t ~ Uniform({1, ..., n_T})
        _ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(self.device)
        # draw standard normal noise, eps ~ N(0, 1)
        noise = torch.randn_like(x)
        # noised image x_t, built in one jump from x_0
        x_t = (
            self.sqrtab[_ts, None, None, None] * x
            + self.sqrtmab[_ts, None, None, None] * noise
        )
        # compare the UNet's predicted noise at this step against the true noise
        return self.loss_mse(noise, self.model(x_t, _ts / self.n_T))

    def sample(self, n_sample, size, device):
        # start from random initial noise, x_T ~ N(0, 1)
        x_i = torch.randn(n_sample, *size).to(device)
        for i in range(self.n_T, 0, -1):
            t_is = torch.tensor([i / self.n_T]).to(device)
            t_is = t_is.repeat(n_sample, 1, 1, 1)
            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
            eps = self.model(x_i, t_is)
            x_i = x_i[:n_sample]
            x_i = self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
        return x_i


class ImageGenerator(object):
    def __init__(self):
        '''
        Initialization: hyperparameters, dataset, network structure, etc.
        '''
        self.epoch = 20
        self.sample_num = 100
        self.batch_size = 256
        self.lr = 0.0001
        self.n_T = 400
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.init_dataloader()
        self.sampler = DDPM(model=Unet(in_channels=1), betas=(1e-4, 0.02), n_T=self.n_T, device=self.device).to(self.device)
        self.optimizer = optim.Adam(self.sampler.model.parameters(), lr=self.lr)

    def init_dataloader(self):
        '''
        Build the dataset and dataloaders
        '''
        tf = transforms.Compose([
            transforms.ToTensor(),
        ])
        train_dataset = MNIST('./data/',
                              train=True,
                              download=True,
                              transform=tf)
        self.train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
        val_dataset = MNIST('./data/',
                            train=False,
                            download=True,
                            transform=tf)
        self.val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)

    def train(self):
        self.sampler.train()
        print('Training started!')
        for epoch in range(self.epoch):
            self.sampler.model.train()
            loss_mean = 0
            for i, (images, labels) in enumerate(self.train_dataloader):
                images, labels = images.to(self.device), labels.to(self.device)
                # compute the diffusion training loss on this batch
                loss = self.sampler(images)
                loss_mean += loss.item()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            train_loss = loss_mean / len(self.train_dataloader)
            print('epoch:{}, loss:{:.4f}'.format(epoch, train_loss))
            self.visualize_results(epoch)

    @torch.no_grad()
    def visualize_results(self, epoch):
        self.sampler.eval()
        # output directory for generated samples
        output_path = 'results/Diffusion'
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        tot_num_samples = self.sample_num
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
        out = self.sampler.sample(tot_num_samples, (1, 28, 28), self.device)
        save_image(out, os.path.join(output_path, '{}.jpg'.format(epoch)), nrow=image_frame_dim)


if __name__ == '__main__':
    generator = ImageGenerator()
    generator.train()

4. Conditional Generation: Code and Results

To generate images under a given condition, we embed the condition vector and inject it into the UNet; in the code below the condition embedding multiplicatively scales the upsampling features (rather than being concatenated to the input).

import torch, time, os
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torch.nn.functional as F


class ResidualConvBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        '''
        standard ResNet style convolutional block
        '''
        self.same_channels = in_channels == out_channels
        self.is_res = is_res
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.is_res:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            # this adds on correct residual in case channels have increased
            if self.same_channels:
                out = x + x2
            else:
                out = x1 + x2
            return out / 1.414
        else:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2


class UnetDown(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetDown, self).__init__()
        '''
        process and downscale the image feature maps
        '''
        layers = [ResidualConvBlock(in_channels, out_channels), nn.MaxPool2d(2)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UnetUp(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()
        '''
        process and upscale the image feature maps
        '''
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        x = torch.cat((x, skip), 1)
        x = self.model(x)
        return x


class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        '''
        generic one layer FC NN for embedding things
        '''
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        return self.model(x)


class Unet(nn.Module):
    def __init__(self, in_channels, n_feat=256, n_classes=10):
        super(Unet, self).__init__()
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)
        self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, 1 * n_feat)
        self.conditionembed1 = EmbedFC(n_classes, 2 * n_feat)
        self.conditionembed2 = EmbedFC(n_classes, 1 * n_feat)
        self.up0 = nn.Sequential(
            # nn.ConvTranspose2d(6 * n_feat, 2 * n_feat, 7, 7), # when concat temb and cemb end up w 6*n_feat
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7), # otherwise just have 2*n_feat
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, c, t):
        '''
        Given a noised image, a condition and the time step, predict the added noise.
        :param x: noised image
        :param c: condition vector (one-hot label)
        :param t: corresponding step, normalized to (0, 1]
        :return: predicted Gaussian noise
        '''
        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)
        hiddenvec = self.to_vec(down2)
        # embed time step
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
        cemb1 = self.conditionembed1(c).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.conditionembed2(c).view(-1, self.n_feat, 1, 1)
        # scale each upsampling output by the condition embedding, add the step
        # embedding, then feed it to the next upsampling layer
        up1 = self.up0(hiddenvec)
        up2 = self.up1(cemb1 * up1 + temb1, down2)
        up3 = self.up2(cemb2 * up2 + temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out


class DDPM(nn.Module):
    def __init__(self, model, betas, n_T, device):
        super(DDPM, self).__init__()
        self.model = model.to(device)
        # register_buffer stores the precomputed alpha-related schedules, saving time later
        for k, v in self.ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)
        self.n_T = n_T
        self.device = device
        self.loss_mse = nn.MSELoss()

    def ddpm_schedules(self, beta1, beta2, T):
        '''
        Precompute the alpha schedule for every step; beta varies linearly here.
        :param beta1: lower bound of beta
        :param beta2: upper bound of beta
        :param T: total number of steps
        '''
        assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"
        beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1  # evenly spaced values from beta1 to beta2
        sqrt_beta_t = torch.sqrt(beta_t)
        alpha_t = 1 - beta_t
        log_alpha_t = torch.log(alpha_t)
        alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()  # cumulative product of alpha
        sqrtab = torch.sqrt(alphabar_t)  # sqrt of the cumulative product
        oneover_sqrta = 1 / torch.sqrt(alpha_t)  # 1 / sqrt(alpha)
        sqrtmab = torch.sqrt(1 - alphabar_t)  # sqrt(1 - cumulative product)
        mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab
        return {
            "alpha_t": alpha_t,  # \alpha_t
            "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
            "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
            "alphabar_t": alphabar_t,  # \bar{\alpha_t}
            "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}, coefficient of x_0 (mean term)
            "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}, std of the added noise
            "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
        }

    def forward(self, x, c):
        """
        During training, pick a random step and random noise for each image
        """
        # pick a random step, t ~ Uniform({1, ..., n_T})
        _ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(self.device)
        # draw standard normal noise, eps ~ N(0, 1)
        noise = torch.randn_like(x)
        # noised image x_t, built in one jump from x_0
        x_t = (
            self.sqrtab[_ts, None, None, None] * x
            + self.sqrtmab[_ts, None, None, None] * noise
        )
        # compare the UNet's predicted noise (given the condition) against the true noise
        return self.loss_mse(noise, self.model(x_t, c, _ts / self.n_T))

    def sample(self, n_sample, c, size, device):
        # start from random initial noise, x_T ~ N(0, 1)
        x_i = torch.randn(n_sample, *size).to(device)
        for i in range(self.n_T, 0, -1):
            t_is = torch.tensor([i / self.n_T]).to(device)
            t_is = t_is.repeat(n_sample, 1, 1, 1)
            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
            eps = self.model(x_i, c, t_is)
            x_i = x_i[:n_sample]
            x_i = self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
        return x_i


class ImageGenerator(object):
    def __init__(self):
        '''
        Initialization: hyperparameters, dataset, network structure, etc.
        '''
        self.epoch = 20
        self.sample_num = 100
        self.batch_size = 256
        self.lr = 0.0001
        self.n_T = 400
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.init_dataloader()
        self.sampler = DDPM(model=Unet(in_channels=1), betas=(1e-4, 0.02), n_T=self.n_T, device=self.device).to(self.device)
        self.optimizer = optim.Adam(self.sampler.model.parameters(), lr=self.lr)

    def init_dataloader(self):
        '''
        Build the dataset and dataloaders
        '''
        tf = transforms.Compose([
            transforms.ToTensor(),
        ])
        train_dataset = MNIST('./data/',
                              train=True,
                              download=True,
                              transform=tf)
        self.train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
        val_dataset = MNIST('./data/',
                            train=False,
                            download=True,
                            transform=tf)
        self.val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)

    def train(self):
        self.sampler.train()
        print('Training started!')
        for epoch in range(self.epoch):
            self.sampler.model.train()
            loss_mean = 0
            for i, (images, labels) in enumerate(self.train_dataloader):
                images, labels = images.to(self.device), labels.to(self.device)
                labels = F.one_hot(labels, num_classes=10).float()
                # feed the images together with their one-hot conditions into the network
                loss = self.sampler(images, labels)
                loss_mean += loss.item()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            train_loss = loss_mean / len(self.train_dataloader)
            print('epoch:{}, loss:{:.4f}'.format(epoch, train_loss))
            self.visualize_results(epoch)

    @torch.no_grad()
    def visualize_results(self, epoch):
        self.sampler.eval()
        # output directory for generated samples
        output_path = 'results/Diffusion'
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        tot_num_samples = self.sample_num
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
        # ten samples per digit class, 0 through 9
        labels = F.one_hot(torch.Tensor(np.repeat(np.arange(10), 10)).to(torch.int64), num_classes=10).to(self.device).float()
        out = self.sampler.sample(tot_num_samples, labels, (1, 28, 28), self.device)
        save_image(out, os.path.join(output_path, '{}.jpg'.format(epoch)), nrow=image_frame_dim)


if __name__ == '__main__':
    generator = ImageGenerator()
    generator.train()

However, if we only ever train with conditions, the network may overfit and lose the ability to generate unconditionally. To support conditional and unconditional generation at the same time, we can use classifier-free guidance, feeding conditional and unconditional inputs to the network during the same training run. Full code may follow in a later post; a rough sketch of the idea is given below.
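As an illustration only (my own sketch, not this post's implementation; drop_prob and w are assumed hyperparameters): during training the condition is randomly zeroed out so the same network also learns the unconditional case, and at sampling time the conditional and unconditional noise predictions are mixed with a guidance weight w:

import torch
import torch.nn.functional as F

drop_prob = 0.1  # assumed probability of dropping the condition during training

def cfg_train_loss(ddpm, images, labels, n_classes=10):
    # one-hot condition, randomly zeroed so the model also learns the unconditional case
    c = F.one_hot(labels, num_classes=n_classes).float()
    keep = (torch.rand(c.shape[0], 1, device=c.device) > drop_prob).float()
    return ddpm(images, c * keep)  # reuses the conditional DDPM.forward above

def cfg_predict_eps(model, x_i, c, t_is, w=2.0):
    # guided noise estimate: (1 + w) * eps_cond - w * eps_uncond
    eps_cond = model(x_i, c, t_is)
    eps_uncond = model(x_i, torch.zeros_like(c), t_is)
    return (1 + w) * eps_cond - w * eps_uncond

Larger w pushes samples harder toward the condition at some cost in diversity; cfg_predict_eps would replace the eps = self.model(...) line inside sample().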



If you want to learn more about image generation, you can refer to my column and these related posts:

Image Generation - Lcm_Tech's blog - CSDN

[Image Generation] (I) DNN: Principles & PyTorch Code Example - CSDN

[Image Generation] (II) GAN: Principles & PyTorch Code Example - CSDN

[Image Generation] (III) VAE: Principles & PyTorch Code Example - CSDN

[Image Generation] (IV) Diffusion: Principles & PyTorch Code Example - CSDN

If you want to learn more about deep learning, you can refer to my other posts:

Deep Learning - Lcm_Tech's blog - CSDN

[Optimizers] (I) SGD: Principles & PyTorch Code Walkthrough - CSDN

[Loss Functions] (I) L1Loss: Principles & PyTorch Code Walkthrough - CSDN

[diffusers] (I) Introduction to the diffusers Library & Framework Code Walkthrough - CSDN
