当前位置:   article > 正文

深入浅出 diffusion(4):pytorch 实现简单 diffusion

深入浅出 diffusion(4):pytorch 实现简单 diffusion

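 1. 训练和采样流程

    训练(对应下文 `DDPM.forward`):取一批真实图像 $x_0$,均匀采样 $t \in \{1, \dots, T\}$ 和噪声 $\epsilon \sim \mathcal{N}(0, I)$,构造 $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$,然后最小化 $\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2$。

    采样(对应下文 `DDPM.sample`):从 $x_T \sim \mathcal{N}(0, I)$ 出发,对 $t = T, \dots, 1$ 依次计算 $x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t)\right) + \sqrt{\beta_t}\, z$,其中 $t > 1$ 时 $z \sim \mathcal{N}(0, I)$,$t = 1$ 时 $z = 0$。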

 2. 无条件实现

  1. import torch, time, os
  2. import numpy as np
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. from torchvision.datasets import MNIST
  6. from torchvision import transforms
  7. from torch.utils.data import DataLoader
  8. from torchvision.utils import save_image
  9. import torch.nn.functional as F
  class ResidualConvBlock(nn.Module):
      """Two stacked 3x3 Conv-BatchNorm-GELU stages with an optional residual shortcut.

      Spatial size is preserved (3x3 kernels, stride 1, padding 1).

      Args:
          in_channels: channels of the input feature map.
          out_channels: channels produced by both stages.
          is_res: when True, add a shortcut and divide the sum by 1.414 (~sqrt(2)),
              presumably to keep the scale of the sum close to that of one branch.
      """

      def __init__(self, in_channels: int, out_channels: int, is_res: bool = False) -> None:
          super().__init__()
          self.same_channels = in_channels == out_channels
          self.is_res = is_res
          self.conv1 = self._conv_bn_gelu(in_channels, out_channels)
          self.conv2 = self._conv_bn_gelu(out_channels, out_channels)

      @staticmethod
      def _conv_bn_gelu(cin: int, cout: int) -> nn.Sequential:
          """Build one 'same'-padded 3x3 convolution followed by BatchNorm2d and GELU."""
          return nn.Sequential(
              nn.Conv2d(cin, cout, 3, 1, 1),
              nn.BatchNorm2d(cout),
              nn.GELU(),
          )

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          hidden = self.conv1(x)
          result = self.conv2(hidden)
          if not self.is_res:
              return result
          # The input can only be the shortcut when channel counts match; otherwise
          # the first stage's output (already out_channels wide) is used instead.
          shortcut = x if self.same_channels else hidden
          return (shortcut + result) / 1.414
  class UnetDown(nn.Module):
      """Encoder stage of the U-Net: a ResidualConvBlock followed by a 2x2 max-pool.

      The max-pool halves the spatial size (floor division).
      """

      def __init__(self, in_channels, out_channels):
          super().__init__()
          self.model = nn.Sequential(
              ResidualConvBlock(in_channels, out_channels),
              nn.MaxPool2d(2),
          )

      def forward(self, x):
          return self.model(x)
  class UnetUp(nn.Module):
      """Decoder stage of the U-Net.

      Concatenates the skip connection onto x along the channel axis, then applies a
      2x2 stride-2 transposed convolution (doubling H and W) and two ResidualConvBlocks.
      ``in_channels`` must equal channels(x) + channels(skip).
      """

      def __init__(self, in_channels, out_channels):
          super().__init__()
          upsample = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
          refine = [ResidualConvBlock(out_channels, out_channels) for _ in range(2)]
          self.model = nn.Sequential(upsample, *refine)

      def forward(self, x, skip):
          merged = torch.cat([x, skip], dim=1)
          return self.model(merged)
  class EmbedFC(nn.Module):
      """Two-layer MLP (Linear -> GELU -> Linear) that embeds input_dim features into emb_dim.

      The input is flattened with ``view(-1, input_dim)``, so it must be contiguous and its
      element count a multiple of input_dim; the output has shape (N, emb_dim).
      """

      def __init__(self, input_dim, emb_dim):
          super().__init__()
          self.input_dim = input_dim
          self.model = nn.Sequential(
              nn.Linear(input_dim, emb_dim),
              nn.GELU(),
              nn.Linear(emb_dim, emb_dim),
          )

      def forward(self, x):
          flat = x.view(-1, self.input_dim)
          return self.model(flat)
  class Unet(nn.Module):
      """Noise-prediction U-Net conditioned on the (normalised) diffusion timestep.

      The hard-coded AvgPool2d(7) / ConvTranspose2d(..., 7, 7) pair appears to assume
      28x28 inputs (28 -> 14 -> 7 after the two down stages); GroupNorm(8, ...) requires
      n_feat to be divisible by 8.
      """

      def __init__(self, in_channels, n_feat=256):
          super().__init__()
          self.in_channels = in_channels
          self.n_feat = n_feat
          # Encoder, then a pooled 1x1 bottleneck vector.
          self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
          self.down1 = UnetDown(n_feat, n_feat)
          self.down2 = UnetDown(n_feat, 2 * n_feat)
          self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
          # Embeddings of the scalar timestep, one per decoder width.
          self.timeembed1 = EmbedFC(1, 2 * n_feat)
          self.timeembed2 = EmbedFC(1, 1 * n_feat)
          # Decoder: expand the bottleneck back to 7x7, then two up stages with skips.
          self.up0 = nn.Sequential(
              nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7),
              nn.GroupNorm(8, 2 * n_feat),
              nn.ReLU(),
          )
          self.up1 = UnetUp(4 * n_feat, n_feat)
          self.up2 = UnetUp(2 * n_feat, n_feat)
          # Head: fuse with the initial features and project back to image channels.
          self.out = nn.Sequential(
              nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
              nn.GroupNorm(8, n_feat),
              nn.ReLU(),
              nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
          )

      def forward(self, x, t):
          """Predict the noise contained in the noisy image x at timestep t.

          :param x: noisy images x_t, a 4-D (B, in_channels, H, W) batch
          :param t: normalised timesteps, one scalar per sample (flattened to (B, 1))
          :return: predicted noise, same shape as x
          """
          feats = self.init_conv(x)
          skip1 = self.down1(feats)
          skip2 = self.down2(skip1)
          bottleneck = self.to_vec(skip2)

          t_wide = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
          t_narrow = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

          # The timestep embedding is added to the decoder features before each up stage.
          h = self.up0(bottleneck)
          h = self.up1(h + t_wide, skip2)
          h = self.up2(h + t_narrow, skip1)
          return self.out(torch.cat((h, feats), 1))
  class DDPM(nn.Module):
      """DDPM wrapper around a noise-prediction network, using a linear beta schedule.

      ``forward`` draws a random step t in {1, ..., n_T}, noises the clean batch in
      closed form and returns the MSE between the true noise and the model's
      prediction. ``sample`` runs ancestral sampling starting from pure noise.

      Args:
          model: network called as ``model(x_t, t / n_T)`` that predicts the noise.
          betas: ``(beta1, beta2)``, first and last value of the linear beta schedule.
          n_T: number of diffusion steps T.
          device: device the model is moved to and training timesteps are created on.
      """

      def __init__(self, model, betas, n_T, device):
          super(DDPM, self).__init__()
          self.model = model.to(device)
          # Registering the schedule as buffers computes it once and lets it follow
          # .to(device) and state_dict together with the model.
          for k, v in self.ddpm_schedules(betas[0], betas[1], n_T).items():
              self.register_buffer(k, v)
          self.n_T = n_T
          self.device = device
          self.loss_mse = nn.MSELoss()

      def ddpm_schedules(self, beta1, beta2, T):
          """Precompute the per-step constants of a linear beta schedule.

          Every tensor has T + 1 entries indexed by step t; training and sampling only
          read indices 1..T.

          :param beta1: lower bound of beta (value at index 0)
          :param beta2: upper bound of beta (value at index T)
          :param T: total number of diffusion steps
          :return: dict of float32 tensors, registered as buffers by ``__init__``
          """
          # beta1 == 0 is tolerated: it only makes index 0 of mab_over_sqrtmab NaN,
          # and index 0 is never read. Negative betas would make sqrt_beta_t NaN.
          assert 0.0 <= beta1 < beta2 < 1.0, "beta1 and beta2 must satisfy 0 <= beta1 < beta2 < 1"
          # Evenly spaced from beta1 (index 0) to beta2 (index T).
          beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
          sqrt_beta_t = torch.sqrt(beta_t)
          alpha_t = 1 - beta_t
          # Cumulative product computed as exp(cumsum(log)); note it includes the
          # index-0 factor (1 - beta1).
          log_alpha_t = torch.log(alpha_t)
          alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()
          sqrtab = torch.sqrt(alphabar_t)
          oneover_sqrta = 1 / torch.sqrt(alpha_t)
          sqrtmab = torch.sqrt(1 - alphabar_t)
          mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab
          return {
              "alpha_t": alpha_t,  # \alpha_t
              "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
              "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}: std of the noise added per sampling step
              "alphabar_t": alphabar_t,  # \bar{\alpha}_t
              "sqrtab": sqrtab,  # \sqrt{\bar{\alpha}_t}: coefficient of x_0 in x_t (mean scale)
              "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha}_t}: std of the noise in x_t
              "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha}_t}
          }

      def forward(self, x):
          """Return the noise-prediction MSE loss for a batch of clean images.

          :param x: clean images, a 4-D (B, C, H, W) batch (the schedule is broadcast
              as (B, 1, 1, 1))
          :return: scalar loss tensor
          """
          # t ~ Uniform{1, ..., n_T} (randint's upper bound is exclusive).
          _ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(self.device)
          noise = torch.randn_like(x)  # eps ~ N(0, I)
          # Closed-form forward process: x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps.
          x_t = (
              self.sqrtab[_ts, None, None, None] * x
              + self.sqrtmab[_ts, None, None, None] * noise
          )
          # The network receives the timestep normalised to (0, 1].
          return self.loss_mse(noise, self.model(x_t, _ts / self.n_T))

      @torch.no_grad()
      def sample(self, n_sample, size, device):
          """Generate n_sample images by ancestral sampling from x_T ~ N(0, I).

          Runs without autograd so the n_T network calls are not recorded in a graph
          when the model has trainable parameters.

          :param n_sample: number of images to generate
          :param size: per-image shape, e.g. (C, H, W)
          :param device: device the samples are created on
          :return: tensor of shape (n_sample, *size)
          """
          x_i = torch.randn(n_sample, *size).to(device)
          for i in range(self.n_T, 0, -1):
              t_is = torch.tensor([i / self.n_T]).to(device)
              t_is = t_is.repeat(n_sample, 1, 1, 1)
              # No noise is injected on the final step (t = 1).
              z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
              eps = self.model(x_i, t_is)
              x_i = self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
          return x_i
  class ImageGenerator(object):
      """Train the unconditional DDPM on MNIST and save a grid of samples after every epoch."""

      def __init__(self):
          """Define the hyper-parameters, build the data loaders, the DDPM sampler and its optimizer."""
          self.epoch = 20  # number of training epochs
          self.sample_num = 100  # images generated for each visualisation grid
          self.batch_size = 256
          self.lr = 0.0001
          self.n_T = 400  # number of diffusion steps
          self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
          self.init_dataloader()
          self.sampler = DDPM(model=Unet(in_channels=1), betas=(1e-4, 0.02), n_T=self.n_T, device=self.device).to(self.device)
          # Only the Unet's parameters are optimised; the schedule lives in DDPM buffers.
          self.optimizer = optim.Adam(self.sampler.model.parameters(), lr=self.lr)

      def init_dataloader(self):
          """Create the MNIST train and validation data loaders (download=True, root './data/')."""
          tf = transforms.Compose([
              transforms.ToTensor(),
          ])
          train_dataset = MNIST('./data/',
                                train=True,
                                download=True,
                                transform=tf)
          # drop_last keeps every training batch at exactly batch_size.
          self.train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
          val_dataset = MNIST('./data/',
                              train=False,
                              download=True,
                              transform=tf)
          self.val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)

      def train(self):
          """Run the training loop; print the mean loss and save samples after each epoch."""
          print('训练开始!!')
          for epoch in range(self.epoch):
              # Re-enter training mode every epoch, because visualize_results switches to eval().
              self.sampler.train()
              loss_mean = 0
              # Unconditional model: the digit labels are not needed.
              for images, _ in self.train_dataloader:
                  images = images.to(self.device)
                  loss = self.sampler(images)
                  loss_mean += loss.item()
                  self.optimizer.zero_grad()
                  loss.backward()
                  self.optimizer.step()
              train_loss = loss_mean / len(self.train_dataloader)
              print('epoch:{}, loss:{:.4f}'.format(epoch, train_loss))
              self.visualize_results(epoch)

      @torch.no_grad()
      def visualize_results(self, epoch):
          """Sample self.sample_num images and save them as results/Diffusion/<epoch>.jpg."""
          self.sampler.eval()
          output_path = 'results/Diffusion'
          os.makedirs(output_path, exist_ok=True)
          tot_num_samples = self.sample_num
          # Grid width: floor(sqrt(n)) images per row.
          image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
          out = self.sampler.sample(tot_num_samples, (1, 28, 28), self.device)
          save_image(out, os.path.join(output_path, '{}.jpg'.format(epoch)), nrow=image_frame_dim)
  if __name__ == "__main__":
      # Script entry point: build data, model and optimizer, then train
      # (a sample grid is written after every epoch).
      generator = ImageGenerator()
      generator.train()

3. 有条件实现

  1. import torch, time, os
  2. import numpy as np
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. from torchvision.datasets import MNIST
  6. from torchvision import transforms
  7. from torch.utils.data import DataLoader
  8. from torchvision.utils import save_image
  9. import torch.nn.functional as F
  class ResidualConvBlock(nn.Module):
      """Two stacked 3x3 Conv-BatchNorm-GELU stages with an optional residual shortcut.

      Spatial size is preserved (3x3 kernels, stride 1, padding 1).

      Args:
          in_channels: channels of the input feature map.
          out_channels: channels produced by both stages.
          is_res: when True, add a shortcut and divide the sum by 1.414 (~sqrt(2)),
              presumably to keep the scale of the sum close to that of one branch.
      """

      def __init__(self, in_channels: int, out_channels: int, is_res: bool = False) -> None:
          super().__init__()
          self.same_channels = in_channels == out_channels
          self.is_res = is_res
          self.conv1 = self._conv_bn_gelu(in_channels, out_channels)
          self.conv2 = self._conv_bn_gelu(out_channels, out_channels)

      @staticmethod
      def _conv_bn_gelu(cin: int, cout: int) -> nn.Sequential:
          """Build one 'same'-padded 3x3 convolution followed by BatchNorm2d and GELU."""
          return nn.Sequential(
              nn.Conv2d(cin, cout, 3, 1, 1),
              nn.BatchNorm2d(cout),
              nn.GELU(),
          )

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          hidden = self.conv1(x)
          result = self.conv2(hidden)
          if not self.is_res:
              return result
          # The input can only be the shortcut when channel counts match; otherwise
          # the first stage's output (already out_channels wide) is used instead.
          shortcut = x if self.same_channels else hidden
          return (shortcut + result) / 1.414
  class UnetDown(nn.Module):
      """Encoder stage of the U-Net: a ResidualConvBlock followed by a 2x2 max-pool.

      The max-pool halves the spatial size (floor division).
      """

      def __init__(self, in_channels, out_channels):
          super().__init__()
          self.model = nn.Sequential(
              ResidualConvBlock(in_channels, out_channels),
              nn.MaxPool2d(2),
          )

      def forward(self, x):
          return self.model(x)
  class UnetUp(nn.Module):
      """Decoder stage of the U-Net.

      Concatenates the skip connection onto x along the channel axis, then applies a
      2x2 stride-2 transposed convolution (doubling H and W) and two ResidualConvBlocks.
      ``in_channels`` must equal channels(x) + channels(skip).
      """

      def __init__(self, in_channels, out_channels):
          super().__init__()
          upsample = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
          refine = [ResidualConvBlock(out_channels, out_channels) for _ in range(2)]
          self.model = nn.Sequential(upsample, *refine)

      def forward(self, x, skip):
          merged = torch.cat([x, skip], dim=1)
          return self.model(merged)
  class EmbedFC(nn.Module):
      """Two-layer MLP (Linear -> GELU -> Linear) that embeds input_dim features into emb_dim.

      The input is flattened with ``view(-1, input_dim)``, so it must be contiguous and its
      element count a multiple of input_dim; the output has shape (N, emb_dim).
      """

      def __init__(self, input_dim, emb_dim):
          super().__init__()
          self.input_dim = input_dim
          self.model = nn.Sequential(
              nn.Linear(input_dim, emb_dim),
              nn.GELU(),
              nn.Linear(emb_dim, emb_dim),
          )

      def forward(self, x):
          flat = x.view(-1, self.input_dim)
          return self.model(flat)
  class Unet(nn.Module):
      """Noise-prediction U-Net conditioned on the timestep and on a condition vector.

      The hard-coded AvgPool2d(7) / ConvTranspose2d(..., 7, 7) pair appears to assume
      28x28 inputs (28 -> 14 -> 7 after the two down stages); GroupNorm(8, ...) requires
      n_feat to be divisible by 8.
      """

      def __init__(self, in_channels, n_feat=256, n_classes=10):
          super().__init__()
          self.in_channels = in_channels
          self.n_feat = n_feat
          # Encoder, then a pooled 1x1 bottleneck vector.
          self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
          self.down1 = UnetDown(n_feat, n_feat)
          self.down2 = UnetDown(n_feat, 2 * n_feat)
          self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
          # Embeddings of the scalar timestep and of the n_classes-wide condition,
          # one per decoder width.
          self.timeembed1 = EmbedFC(1, 2 * n_feat)
          self.timeembed2 = EmbedFC(1, 1 * n_feat)
          self.conditionembed1 = EmbedFC(n_classes, 2 * n_feat)
          self.conditionembed2 = EmbedFC(n_classes, 1 * n_feat)
          # Decoder: expand the bottleneck back to 7x7, then two up stages with skips.
          self.up0 = nn.Sequential(
              nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7),
              nn.GroupNorm(8, 2 * n_feat),
              nn.ReLU(),
          )
          self.up1 = UnetUp(4 * n_feat, n_feat)
          self.up2 = UnetUp(2 * n_feat, n_feat)
          # Head: fuse with the initial features and project back to image channels.
          self.out = nn.Sequential(
              nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
              nn.GroupNorm(8, n_feat),
              nn.ReLU(),
              nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
          )

      def forward(self, x, c, t):
          """Predict the noise in x given the condition vector c and the timestep t.

          :param x: noisy images x_t, a 4-D (B, in_channels, H, W) batch
          :param c: condition vectors with n_classes features per sample
              (one-hot digit labels in this script)
          :param t: normalised timesteps, one scalar per sample (flattened to (B, 1))
          :return: predicted noise, same shape as x
          """
          feats = self.init_conv(x)
          skip1 = self.down1(feats)
          skip2 = self.down2(skip1)
          bottleneck = self.to_vec(skip2)

          t_wide = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
          t_narrow = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
          c_wide = self.conditionembed1(c).view(-1, self.n_feat * 2, 1, 1)
          c_narrow = self.conditionembed2(c).view(-1, self.n_feat, 1, 1)

          # Before each up stage the condition embedding scales the features and the
          # timestep embedding is added on top.
          h = self.up0(bottleneck)
          h = self.up1(c_wide * h + t_wide, skip2)
          h = self.up2(c_narrow * h + t_narrow, skip1)
          return self.out(torch.cat((h, feats), 1))
  class DDPM(nn.Module):
      """Conditional DDPM wrapper around a noise-prediction network, linear beta schedule.

      ``forward`` draws a random step t in {1, ..., n_T}, noises the clean batch in
      closed form and returns the MSE between the true noise and the model's
      prediction. ``sample`` runs ancestral sampling starting from pure noise.

      Args:
          model: network called as ``model(x_t, c, t / n_T)`` that predicts the noise.
          betas: ``(beta1, beta2)``, first and last value of the linear beta schedule.
          n_T: number of diffusion steps T.
          device: device the model is moved to and training timesteps are created on.
      """

      def __init__(self, model, betas, n_T, device):
          super(DDPM, self).__init__()
          self.model = model.to(device)
          # Registering the schedule as buffers computes it once and lets it follow
          # .to(device) and state_dict together with the model.
          for k, v in self.ddpm_schedules(betas[0], betas[1], n_T).items():
              self.register_buffer(k, v)
          self.n_T = n_T
          self.device = device
          self.loss_mse = nn.MSELoss()

      def ddpm_schedules(self, beta1, beta2, T):
          """Precompute the per-step constants of a linear beta schedule.

          Every tensor has T + 1 entries indexed by step t; training and sampling only
          read indices 1..T.

          :param beta1: lower bound of beta (value at index 0)
          :param beta2: upper bound of beta (value at index T)
          :param T: total number of diffusion steps
          :return: dict of float32 tensors, registered as buffers by ``__init__``
          """
          # beta1 == 0 is tolerated: it only makes index 0 of mab_over_sqrtmab NaN,
          # and index 0 is never read. Negative betas would make sqrt_beta_t NaN.
          assert 0.0 <= beta1 < beta2 < 1.0, "beta1 and beta2 must satisfy 0 <= beta1 < beta2 < 1"
          # Evenly spaced from beta1 (index 0) to beta2 (index T).
          beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
          sqrt_beta_t = torch.sqrt(beta_t)
          alpha_t = 1 - beta_t
          # Cumulative product computed as exp(cumsum(log)); note it includes the
          # index-0 factor (1 - beta1).
          log_alpha_t = torch.log(alpha_t)
          alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()
          sqrtab = torch.sqrt(alphabar_t)
          oneover_sqrta = 1 / torch.sqrt(alpha_t)
          sqrtmab = torch.sqrt(1 - alphabar_t)
          mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab
          return {
              "alpha_t": alpha_t,  # \alpha_t
              "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
              "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}: std of the noise added per sampling step
              "alphabar_t": alphabar_t,  # \bar{\alpha}_t
              "sqrtab": sqrtab,  # \sqrt{\bar{\alpha}_t}: coefficient of x_0 in x_t (mean scale)
              "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha}_t}: std of the noise in x_t
              "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha}_t}
          }

      def forward(self, x, c):
          """Return the noise-prediction MSE loss for clean images x under condition c.

          :param x: clean images, a 4-D (B, C, H, W) batch (the schedule is broadcast
              as (B, 1, 1, 1))
          :param c: condition vectors, one per image, passed straight to the model
          :return: scalar loss tensor
          """
          # t ~ Uniform{1, ..., n_T} (randint's upper bound is exclusive).
          _ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(self.device)
          noise = torch.randn_like(x)  # eps ~ N(0, I)
          # Closed-form forward process: x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps.
          x_t = (
              self.sqrtab[_ts, None, None, None] * x
              + self.sqrtmab[_ts, None, None, None] * noise
          )
          # The network receives the timestep normalised to (0, 1].
          return self.loss_mse(noise, self.model(x_t, c, _ts / self.n_T))

      @torch.no_grad()
      def sample(self, n_sample, c, size, device):
          """Generate n_sample images under condition c by ancestral sampling from x_T ~ N(0, I).

          Runs without autograd so the n_T network calls are not recorded in a graph
          when the model has trainable parameters.

          :param n_sample: number of images to generate
          :param c: condition vectors, one row per generated image
          :param size: per-image shape, e.g. (C, H, W)
          :param device: device the samples are created on
          :return: tensor of shape (n_sample, *size)
          """
          x_i = torch.randn(n_sample, *size).to(device)
          for i in range(self.n_T, 0, -1):
              t_is = torch.tensor([i / self.n_T]).to(device)
              t_is = t_is.repeat(n_sample, 1, 1, 1)
              # No noise is injected on the final step (t = 1).
              z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
              eps = self.model(x_i, c, t_is)
              x_i = self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
          return x_i
  class ImageGenerator(object):
      """Train the class-conditional DDPM on MNIST and save a labelled sample grid after every epoch."""

      def __init__(self):
          """Define the hyper-parameters, build the data loaders, the DDPM sampler and its optimizer."""
          self.epoch = 20  # number of training epochs
          self.sample_num = 100  # images generated for each visualisation grid
          self.batch_size = 256
          self.lr = 0.0001
          self.n_T = 400  # number of diffusion steps
          self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
          self.init_dataloader()
          self.sampler = DDPM(model=Unet(in_channels=1), betas=(1e-4, 0.02), n_T=self.n_T, device=self.device).to(self.device)
          # Only the Unet's parameters are optimised; the schedule lives in DDPM buffers.
          self.optimizer = optim.Adam(self.sampler.model.parameters(), lr=self.lr)

      def init_dataloader(self):
          """Create the MNIST train and validation data loaders (download=True, root './data/')."""
          tf = transforms.Compose([
              transforms.ToTensor(),
          ])
          train_dataset = MNIST('./data/',
                                train=True,
                                download=True,
                                transform=tf)
          # drop_last keeps every training batch at exactly batch_size.
          self.train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
          val_dataset = MNIST('./data/',
                              train=False,
                              download=True,
                              transform=tf)
          self.val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)

      def train(self):
          """Run the training loop; print the mean loss and save samples after each epoch."""
          print('训练开始!!')
          for epoch in range(self.epoch):
              # Re-enter training mode every epoch, because visualize_results switches to eval().
              self.sampler.train()
              loss_mean = 0
              for images, labels in self.train_dataloader:
                  images, labels = images.to(self.device), labels.to(self.device)
                  # The condition is the one-hot encoded digit label.
                  labels = F.one_hot(labels, num_classes=10).float()
                  loss = self.sampler(images, labels)
                  loss_mean += loss.item()
                  self.optimizer.zero_grad()
                  loss.backward()
                  self.optimizer.step()
              train_loss = loss_mean / len(self.train_dataloader)
              print('epoch:{}, loss:{:.4f}'.format(epoch, train_loss))
              self.visualize_results(epoch)

      @torch.no_grad()
      def visualize_results(self, epoch):
          """Sample self.sample_num labelled images and save them as results/Diffusion/<epoch>.jpg."""
          self.sampler.eval()
          output_path = 'results/Diffusion'
          os.makedirs(output_path, exist_ok=True)
          tot_num_samples = self.sample_num
          # Grid width: floor(sqrt(n)) images per row.
          image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
          # One condition per sample, derived from sample_num so any grid size works:
          # grid row r shows digit r % 10. For the default 100 samples this is ten rows
          # of ten images, digits 0..9 top to bottom.
          class_ids = (torch.arange(tot_num_samples) // image_frame_dim) % 10
          labels = F.one_hot(class_ids, num_classes=10).to(self.device).float()
          out = self.sampler.sample(tot_num_samples, labels, (1, 28, 28), self.device)
          save_image(out, os.path.join(output_path, '{}.jpg'.format(epoch)), nrow=image_frame_dim)
  if __name__ == "__main__":
      # Script entry point: build data, model and optimizer, then train
      # (a labelled sample grid is written after every epoch).
      generator = ImageGenerator()
      generator.train()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/47400?site
推荐阅读
相关标签
  

闽ICP备14008679号