当前位置:   article > 正文

昇思25天学习打卡营第13天|CycleGAN 图像风格迁移互换全流程解析

昇思25天学习打卡营第13天|CycleGAN 图像风格迁移互换全流程解析

     

目录

数据集下载和加载

可视化

构建生成器

构建判别器

优化器和损失函数

前向计算

计算梯度和反向传播

模型训练

模型推理


数据集下载和加载


        使用 download 接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用 pip install download 安装 download 包。使用 MindSpore 的 MindDataset 接口读取和解析数据集。

        代码如下:

  1. %%capture captured_output  
  2. # 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号  
  3. !pip uninstall mindspore -y  
  4. !pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14  
  5. # 查看当前 mindspore 版本  
  6. !pip show mindspore  
  7. from download import download  
  8. url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"  
  9. download(url, ".", kind="zip", replace=True)  
  10. from mindspore.dataset import MindDataset  
  11. # 读取MindRecord格式数据  
  12. name_mr = "./CycleGAN_apple2orange/apple2orange_train.mindrecord"  
  13. data = MindDataset(dataset_files=name_mr)  
  14. print("Datasize: ", data.get_dataset_size())  
  15. batch_size = 1  
  16. dataset = data.batch(batch_size)  
  17. datasize = dataset.get_dataset_size()  

        分析:首先,对已安装的 MindSpore 库进行卸载尝试。紧接着,通过特定的镜像源安装指定版本(2.2.14)的 MindSpore 库。最后,查看当下所安装的 MindSpore 版本。

        从指定的 URL 下载一个 zip 压缩文件至当前目录,同时指定若存在同名文件则予以替换。

        从 MindSpore 的 dataset 模块导入 MindDataset 类,之后指定一个名为 name_mr 的路径用于读取 MindRecord 格式的数据,并将其存储于 data 变量内,最终打印出此数据集的规模大小。

        设置批量大小为 1,对原始数据集进行批量处理,将处理后的数据集存于 dataset 变量中,最后获取处理后的数据集的大小并存储在 datasize 变量里。

可视化


        通过 create_dict_iterator 函数将数据转换成字典迭代器,然后使用 matplotlib 模块可视化部分训练数据。从 dataset 中读取前 5 组图像数据,对其进行处理,并以特定的布局在一个图形中展示这些图像。

        代码如下:

  1. import numpy as np  
  2. import matplotlib.pyplot as plt  
  3. mean = 0.5 * 255  
  4. std = 0.5 * 255  
  5. plt.figure(figsize=(125), dpi=60)  
  6. for i, data in enumerate(dataset.create_dict_iterator()):  
  7.     if i < 5:  
  8.         show_images_a = data["image_A"].asnumpy()  
  9.         show_images_b = data["image_B"].asnumpy()  
  10.   
  11.         plt.subplot(25, i+1)  
  12.         show_images_a = (show_images_a[0] * std + mean).astype(np.uint8).transpose((120))  
  13.         plt.imshow(show_images_a)  
  14.         plt.axis("off")  
  15.   
  16.         plt.subplot(25, i+6)  
  17.         show_images_b = (show_images_b[0] * std + mean).astype(np.uint8).transpose((120))  
  18.         plt.imshow(show_images_b)  
  19.         plt.axis("off")  
  20.     else:  
  21.         break  
  22. plt.show()  

        分析:首先,导入了 numpy 库并简称为 np ,同时引入了 matplotlib 库的 pyplot 模块以用于绘图。

        接着,定义了两个变量 mean 和 std ,它们将服务于后续的数据处理工作。

        之后,创建了一个新的图形,将其大小设定为 (12, 5) ,分辨率设为 60 。

        随后,运用 enumerate 函数对 dataset 的 create_dict_iterator 所生成的迭代器进行遍历。对于前面的 5 个数据项,如果索引 i 小于 5 ,就从数据中提取 image_A 和 image_B 并转换为 numpy 数组。而后对这些图像数据进行一系列处理,包含乘以 std 再加上 mean ,进行类型转换以及维度变换。使用 plt.subplot 在图形里创建子图,并于子图中展示处理后的图像,同时关闭坐标轴。一旦索引超过 4 ,就退出循环。

        最终,显示绘制好的图形。

        运行结果:

构建生成器


        首先,导入了 mindspore.nn 模块,并简称为 nn ;导入了 mindspore.ops 模块,并简称为 ops ;还从 mindspore.common.initializer 中引入了 Normal 以用于权重的初始化,并设定了一种权重初始化方式 weight_init 。

        定义了一个名为 ConvNormReLU 的类,该类继承自 nn.Cell 。此类别旨在构建涵盖卷积、归一化以及激活函数的层。在 __init__ 方法里,完成了各类参数的设定以及层的初始化操作,在 construct 方法中明确了前向传播的计算逻辑。

        定义了 ResidualBlock 类,用于构建残差块。它包含了两个 ConvNormReLU 层,并且能够选择是否运用 Dropout 。

        定义了 ResNetGenerator 类,这是一个基于残差网络的生成器。在 __init__ 方法中搭建了网络的各个层级,在 construct 方法中确定了前向传播的流程。

        再次强调,定义的 ResNetGenerator 类是一个基于残差网络的生成器,在 __init__ 方法中构建了网络的各层结构,在 construct 方法中定义了前向传播的流程。

        最后,实例化了两个 ResNetGenerator 对象,分别为 net_rg_a 和 net_rg_b ,并分别对它们的参数名称进行了更新。

        代码如下:

  1. import mindspore.nn as nn  
  2. import mindspore.ops as ops  
  3. from mindspore.common.initializer import Normal  
  4. weight_init = Normal(sigma=0.02)  
  5. class ConvNormReLU(nn.Cell):  
  6.     def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',  
  7.                  pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):  
  8.         super(ConvNormReLU, self).__init__()  
  9.         norm = nn.BatchNorm2d(out_planes)  
  10.         if norm_mode == 'instance':  
  11.             norm = nn.BatchNorm2d(out_planes, affine=False)  
  12.         has_bias = (norm_mode == 'instance')  
  13.         if padding is None:  
  14.             padding = (kernel_size - 1) // 2  
  15.         if pad_mode == 'CONSTANT':  
  16.             if transpose:  
  17.                 conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',  
  18.                                           has_bias=has_bias, weight_init=weight_init)  
  19.             else:  
  20.                 conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',  
  21.                                  has_bias=has_bias, padding=padding, weight_init=weight_init)  
  22.             layers = [conv, norm]  
  23.         else:  
  24.             paddings = ((00), (00), (padding, padding), (padding, padding))  
  25.             pad = nn.Pad(paddings=paddings, mode=pad_mode)  
  26.             if transpose:  
  27.                 conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',  
  28.                                           has_bias=has_bias, weight_init=weight_init)  
  29.             else:  
  30.                 conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',  
  31.                                  has_bias=has_bias, weight_init=weight_init)  
  32.             layers = [pad, conv, norm]  
  33.         if use_relu:  
  34.             relu = nn.ReLU()  
  35.             if alpha > 0:  
  36.                 relu = nn.LeakyReLU(alpha)  
  37.             layers.append(relu)  
  38.         self.features = nn.SequentialCell(layers)  
  39.   
  40.     def construct(self, x):  
  41.         output = self.features(x)  
  42.         return output  
  43. class ResidualBlock(nn.Cell):  
  44.     def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode="CONSTANT"):  
  45.         super(ResidualBlock, self).__init__()  
  46.         self.conv1 = ConvNormReLU(dim, dim, 310, norm_mode, pad_mode)  
  47.         self.conv2 = ConvNormReLU(dim, dim, 310, norm_mode, pad_mode, use_relu=False)  
  48.         self.dropout = dropout  
  49.         if dropout:  
  50.             self.dropout = nn.Dropout(p=0.5)  
  51.     def construct(self, x):  
  52.         out = self.conv1(x)  
  53.         if self.dropout:  
  54.             out = self.dropout(out)  
  55.         out = self.conv2(out)  
  56.         return x + out  
  57. class ResNetGenerator(nn.Cell):  
  58.     def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,  
  59.                  pad_mode="CONSTANT"):  
  60.         super(ResNetGenerator, self).__init__()  
  61.         self.conv_in = ConvNormReLU(input_channel, output_channel, 71, alpha, norm_mode, pad_mode=pad_mode)  
  62.         self.down_1 = ConvNormReLU(output_channel, output_channel * 232, alpha, norm_mode)  
  63.         self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 432, alpha, norm_mode)  
  64.         layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layers  
  65.         self.residuals = nn.SequentialCell(layers)  
  66.         self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 232, alpha, norm_mode, transpose=True)  
  67.         self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 32, alpha, norm_mode, transpose=True)  
  68.         if pad_mode == "CONSTANT":  
  69.             self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',  
  70.                                       padding=3, weight_init=weight_init)  
  71.         else:  
  72.             pad = nn.Pad(paddings=((00), (00), (33), (33)), mode=pad_mode)  
  73.             conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)  
  74.             self.conv_out = nn.SequentialCell([pad, conv])  
  75.   
  76.     def construct(self, x):  
  77.         x = self.conv_in(x)  
  78.         x = self.down_1(x)  
  79.         x = self.down_2(x)  
  80.         x = self.residuals(x)  
  81.         x = self.up_2(x)  
  82.         x = self.up_1(x)  
  83.         output = self.conv_out(x)  
  84.         return ops.tanh(output)  
  85. # 实例化生成器  
  86. net_rg_a = ResNetGenerator()  
  87. net_rg_a.update_parameters_name('net_rg_a.')  
  88. net_rg_b = ResNetGenerator()  
  89. net_rg_b.update_parameters_name('net_rg_b.')  

构建判别器


        判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。通过一系列的 Conv2d 、 BatchNorm2d 和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数得到最终概率。

        代码如下:

  1. # 定义判别器  
  2. class Discriminator(nn.Cell):  
  3.     def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):  
  4.         super(Discriminator, self).__init__()  
  5.         kernel_size = 4  
  6.         layers = [nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init),  
  7.                   nn.LeakyReLU(alpha)]  
  8.         nf_mult = output_channel  
  9.         for i in range(1, n_layers):  
  10.             nf_mult_prev = nf_mult  
  11.             nf_mult = min(2 ** i, 8) * output_channel  
  12.             layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))  
  13.         nf_mult_prev = nf_mult  
  14.         nf_mult = min(2 ** n_layers, 8) * output_channel  
  15.         layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))  
  16.         layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init))  
  17.         self.features = nn.SequentialCell(layers)  
  18.   
  19.     def construct(self, x):  
  20.         output = self.features(x)  
  21.         return output  
  22.   
  23. # 判别器初始化  
  24. net_d_a = Discriminator()  
  25. net_d_a.update_parameters_name('net_d_a.')  
  26.   
  27. net_d_b = Discriminator()  
  28. net_d_b.update_parameters_name('net_d_b.'

        分析:定义了一个名为 Discriminator 的类,它继承自 nn.Cell 。在 __init__ 方法中,首先设置了一些参数,如输入通道数、输出通道数、层数等。然后初始化了一些层,包括卷积层和激活函数层,并通过一个循环逐步构建更多的卷积和归一化激活层。最后将这些层组合成一个顺序模型 nn.SequentialCell 并存储在 self.features 中。

        construct 方法定义了前向传播的计算逻辑,即输入数据通过 self.features 进行处理并返回输出。

        这部分代码实例化了两个 Discriminator 对象 net_d_a 和 net_d_b ,并分别更新了它们的参数名称。

优化器和损失函数


        创建了四个优化器对象。optimizer_rg_a 和 optimizer_rg_b 用于优化生成器 net_rg_a 和 net_rg_b 的可训练参数,optimizer_d_a 和 optimizer_d_b 用于优化判别器 net_d_a 和 net_d_b 的可训练参数。这里使用的优化算法是 Adam 算法,学习率均为 0.0002,beta1 值均为 0.5 。

        定义了两个损失函数,loss_fn 是均方误差损失函数(MSELoss),采用均值约简方式;l1_loss 是平均绝对误差损失函数(L1Loss)。

        定义了一个名为 gan_loss 的函数,用于计算 GAN 网络的损失。首先将目标值 target 扩展为与预测值 predict 形状相同且元素全为 1 乘以 target 的张量,然后使用之前定义的 loss_fn 计算预测值和扩展后的目标值之间的损失,并返回该损失值。

        代码如下:

  1. # 构建生成器,判别器优化器  
  2. optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)  
  3. optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)  
  4.   
  5. optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)  
  6. optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)  
  7.   
  8. # GAN网络损失函数,这里最后一层不使用sigmoid函数  
  9. loss_fn = nn.MSELoss(reduction='mean')  
  10. l1_loss = nn.L1Loss("mean")  
  11.   
  12. def gan_loss(predict, target):  
  13.     target = ops.ones_like(predict) * target  
  14.     loss = loss_fn(predict, target)  
  15.     return loss  

前向计算


        为了减少模型振荡[1],这里遵循 Shrivastava 等人的策略[2],使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。这里创建 image_pool 函数,保留了一个图像缓冲区,用于存储生成器生成前的50个图像。搭建模型前向计算损失的过程,过程如下代码。

        代码如下:

  1. import mindspore as ms  
  2.   
  3. # 前向计算  
  4.   
  5. def generator(img_a, img_b):  
  6.     fake_a = net_rg_b(img_b)  
  7.     fake_b = net_rg_a(img_a)  
  8.     rec_a = net_rg_b(fake_b)  
  9.     rec_b = net_rg_a(fake_a)  
  10.     identity_a = net_rg_b(img_a)  
  11.     identity_b = net_rg_a(img_b)  
  12.     return fake_a, fake_b, rec_a, rec_b, identity_a, identity_b  
  13.   
  14. lambda_a = 10.0  
  15. lambda_b = 10.0  
  16. lambda_idt = 0.5  
  17.   
  18. def generator_forward(img_a, img_b):  
  19.     true = Tensor(True, dtype.bool_)  
  20.     fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)  
  21.     loss_g_a = gan_loss(net_d_b(fake_b), true)  
  22.     loss_g_b = gan_loss(net_d_a(fake_a), true)  
  23.     loss_c_a = l1_loss(rec_a, img_a) * lambda_a  
  24.     loss_c_b = l1_loss(rec_b, img_b) * lambda_b  
  25.     loss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idt  
  26.     loss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idt  
  27.     loss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_b  
  28.     return fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b  
  29.   
  30. def generator_forward_grad(img_a, img_b):  
  31.     _, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)  
  32.     return loss_g  
  33.   
  34. def discriminator_forward(img_a, img_b, fake_a, fake_b):  
  35.     false = Tensor(False, dtype.bool_)  
  36.     true = Tensor(True, dtype.bool_)  
  37.     d_fake_a = net_d_a(fake_a)  
  38.     d_img_a = net_d_a(img_a)  
  39.     d_fake_b = net_d_b(fake_b)  
  40.     d_img_b = net_d_b(img_b)  
  41.     loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)  
  42.     loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)  
  43.     loss_d = (loss_d_a + loss_d_b) * 0.5  
  44.     return loss_d  
  45.   
  46. def discriminator_forward_a(img_a, fake_a):  
  47.     false = Tensor(False, dtype.bool_)  
  48.     true = Tensor(True, dtype.bool_)  
  49.     d_fake_a = net_d_a(fake_a)  
  50.     d_img_a = net_d_a(img_a)  
  51.     loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)  
  52.     return loss_d_a  
  53.   
  54. def discriminator_forward_b(img_b, fake_b):  
  55.     false = Tensor(False, dtype.bool_)  
  56.     true = Tensor(True, dtype.bool_)  
  57.     d_fake_b = net_d_b(fake_b)  
  58.     d_img_b = net_d_b(img_b)  
  59.     loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)  
  60.     return loss_d_b  
  61.   
  62. # 保留了一个图像缓冲区,用来存储之前创建的50个图像  
  63. pool_size = 50  
  64. def image_pool(images):  
  65.     num_imgs = 0  
  66.     image1 = []  
  67.     if isinstance(images, Tensor):  
  68.         images = images.asnumpy()  
  69.     return_images = []  
  70.     for image in images:  
  71.         if num_imgs < pool_size:  
  72.             num_imgs = num_imgs + 1  
  73.             image1.append(image)  
  74.             return_images.append(image)  
  75.         else:  
  76.             if random.uniform(01) > 0.5:  
  77.                 random_id = random.randint(0, pool_size - 1)  
  78.   
  79.                 tmp = image1[random_id].copy()  
  80.                 image1[random_id] = image  
  81.                 return_images.append(tmp)  
  82.   
  83.             else:  
  84.                 return_images.append(image)  
  85.     output = Tensor(return_images, ms.float32)  
  86.     if output.ndim != 4:  
  87.         raise ValueError("img should be 4d, but get shape {}".format(output.shape))  
  88.     return output  

        分析:首先导入了 mindspore 库并简称为 ms 。

        定义了 generator 函数,用于执行生成器的前向计算。它接受两个图像输入 img_a 和 img_b ,通过生成器网络计算得到一系列的输出,包括生成的假图像、重建图像和恒等映射图像,并返回这些结果。

        定义了一些用于计算损失的权重系数。

        generator_forward 函数在 generator 函数的基础上计算生成器的各种损失,并将损失组合得到总的生成器损失 loss_g ,同时返回相关的输出和损失值。

        generator_forward_grad 函数获取 generator_forward 计算得到的生成器总损失 loss_g 并返回。

        discriminator_forward 函数计算判别器的损失。

        discriminator_forward 函数计算判别器的损失。

        discriminator_forward_b 函数计算与图像 img_b 和生成的假图像 fake_b 相关的判别器损失。

        image_pool 函数实现了一个图像缓冲区,用于存储一定数量的图像,并根据随机条件进行图像的替换和返回。

计算梯度和反向传播


        其中梯度计算也是分开不同的模型来进行的,详情见如下代码:

        代码如下:

  1. from mindspore import value_and_grad  
  2.   
  3. # 实例化求梯度的方法  
  4. grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())  
  5. grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())  
  6.   
  7. grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())  
  8. grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())  
  9.   
  10. # 计算生成器的梯度,反向传播更新参数  
  11. def train_step_g(img_a, img_b):  
  12.     net_d_a.set_grad(False)  
  13.     net_d_b.set_grad(False)  
  14.   
  15.     fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)  
  16.   
  17.     _, grads_g_a = grad_g_a(img_a, img_b)  
  18.     _, grads_g_b = grad_g_b(img_a, img_b)  
  19.     optimizer_rg_a(grads_g_a)  
  20.     optimizer_rg_b(grads_g_b)  
  21.   
  22.     return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib  
  23.   
  24. # 计算判别器的梯度,反向传播更新参数  
  25. def train_step_d(img_a, img_b, fake_a, fake_b):  
  26.     net_d_a.set_grad(True)  
  27.     net_d_b.set_grad(True)  
  28.   
  29.     loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)  
  30.     loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)  
  31.   
  32.     loss_d = (loss_d_a + loss_d_b) * 0.5  
  33.   
  34.     optimizer_d_a(grads_d_a)  
  35.     optimizer_d_b(grads_d_b)  
  36.   
  37.     return loss_d  

        分析:从 mindspore 库中导入 value_and_grad 函数,用于计算函数的输出值和梯度。这里为生成器和判别器的相关函数创建了求梯度的实例。在 train_step_g 函数中,首先设置判别器 net_d_a 和 net_d_b 的梯度计算为 False ,即暂时不计算判别器的梯度。在 train_step_g 函数中,首先设置判别器 net_d_a 和 net_d_b 的梯度计算为 False ,即暂时不计算判别器的梯度。通过之前创建的梯度计算实例获取生成器的梯度,并使用相应的优化器 optimizer_rg_a 和 optimizer_rg_b 进行参数更新。在 train_step_d 函数中,设置判别器 net_d_a 和 net_d_b 的梯度计算为 True ,准备计算判别器的梯度。获取判别器的损失和梯度。计算总的判别器损失,并使用相应的优化器 optimizer_d_a 和 optimizer_d_b 进行参数更新。

        总的来说,这段代码定义了用于训练生成器和判别器的函数,通过计算梯度并使用优化器来更新网络的参数。

模型训练


        训练分为两个主要部分:训练判别器和训练生成器,在前文的判别器损失函数中,论文采用了最小二乘损失代替负对数似然目标。

        下面定义了生成器和判别器的训练过程:

        代码如下:

  1. import os  
  2. import time  
  3. import random  
  4. import numpy as np  
  5. from PIL import Image  
  6. from mindspore import Tensor, save_checkpoint  
  7. from mindspore import dtype  
  8.   
  9. # 由于时间原因,epochs设置为1,可根据需求进行调整  
  10. epochs = 1  
  11. save_step_num = 80  
  12. save_checkpoint_epochs = 1  
  13. save_ckpt_dir = './train_ckpt_outputs/'  
  14.   
  15. print('Start training!')  
  16.   
  17. for epoch in range(epochs):  
  18.     g_loss = []  
  19.     d_loss = []  
  20.     start_time_e = time.time()  
  21.     for step, data in enumerate(dataset.create_dict_iterator()):  
  22.         start_time_s = time.time()  
  23.         img_a = data["image_A"]  
  24.         img_b = data["image_B"]  
  25.         res_g = train_step_g(img_a, img_b)  
  26.         fake_a = res_g[0]  
  27.         fake_b = res_g[1]  
  28.   
  29.         res_d = train_step_d(img_a, img_b, image_pool(fake_a), image_pool(fake_b))  
  30.         loss_d = float(res_d.asnumpy())  
  31.         step_time = time.time() - start_time_s  
  32.   
  33.         res = []  
  34.         for item in res_g[2:]:  
  35.             res.append(float(item.asnumpy()))  
  36.         g_loss.append(res[0])  
  37.         d_loss.append(loss_d)  
  38.   
  39.         if step % save_step_num == 0:  
  40.             print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "  
  41.                   f"step:[{int(step):>4d}/{int(datasize):>4d}], "  
  42.                   f"time:{step_time:>3f}s,\n"  
  43.                   f"loss_g:{res[0]:.2f}, loss_d:{loss_d:.2f}, "  
  44.                   f"loss_g_a: {res[1]:.2f}, loss_g_b: {res[2]:.2f}, "  
  45.                   f"loss_c_a: {res[3]:.2f}, loss_c_b: {res[4]:.2f}, "  
  46.                   f"loss_idt_a: {res[5]:.2f}, loss_idt_b: {res[6]:.2f}")  
  47.   
  48.     epoch_cost = time.time() - start_time_e  
  49.     per_step_time = epoch_cost / datasize  
  50.     mean_loss_d, mean_loss_g = sum(d_loss) / datasize, sum(g_loss) / datasize  
  51.   
  52.     print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "  
  53.           f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time:.2f}, "  
  54.           f"mean_g_loss:{mean_loss_g:.2f}, mean_d_loss:{mean_loss_d :.2f}")  
  55.   
  56.     if epoch % save_checkpoint_epochs == 0:  
  57.         os.makedirs(save_ckpt_dir, exist_ok=True)  
  58.         save_checkpoint(net_rg_a, os.path.join(save_ckpt_dir, f"g_a_{epoch}.ckpt"))  
  59.         save_checkpoint(net_rg_b, os.path.join(save_ckpt_dir, f"g_b_{epoch}.ckpt"))  
  60.         save_checkpoint(net_d_a, os.path.join(save_ckpt_dir, f"d_a_{epoch}.ckpt"))  
  61.         save_checkpoint(net_d_b, os.path.join(save_ckpt_dir, f"d_b_{epoch}.ckpt"))  
  62.   
  63. print('End of training!')  

        分析:首先导入了所需的库和模块。

        设置了训练的轮数 epochs 、保存检查点的步长 save_step_num 、保存检查点的轮数间隔 save_checkpoint_epochs 以及保存检查点的目录 save_ckpt_dir 。

        打印出开始训练的提示信息。

        开始进行训练轮数的循环。初始化生成器损失列表 g_loss 和判别器损失列表 d_loss ,并记录每一轮训练开始的时间。

        在每一轮中,通过数据集的迭代器获取数据,并执行生成器的训练步骤,获取训练结果和生成的假图像。

        执行判别器的训练步骤,获取判别器的损失,并计算每一步的训练时间。

        处理生成器的训练结果,将相关损失值添加到对应的损失列表中。

        如果当前步数是保存步长的整数倍,打印出当前的训练信息,包括轮数、步数、训练时间以及各种损失值。

        计算每一轮的训练总时间、每步平均时间,以及判别器和生成器损失的平均值。

        打印出每一轮训练结束后的总结信息,包括轮数、训练总时间、每步平均时间以及平均损失值。

        如果当前轮数是保存检查点轮数间隔的整数倍,创建保存目录,并保存生成器和判别器的模型检查点。

        最后打印训练结束的提示信息。

        运行结果:

  1. Start training!  
  2. Epoch:[  1/  1], step:[   0/1019], time:140.084908s,  
  3. loss_g:20.59, loss_d:1.09, loss_g_a: 1.03, loss_g_b: 1.07, loss_c_a: 5.08, loss_c_b: 7.28, loss_idt_a: 2.55, loss_idt_b: 3.59  
  4. Epoch:[  1/  1], step:[  80/1019], time:0.459862s,  
  5. loss_g:9.79, loss_d:0.35, loss_g_a: 0.60, loss_g_b: 0.32, loss_c_a: 2.05, loss_c_b: 3.97, loss_idt_a: 1.03, loss_idt_b: 1.80  
  6. Epoch:[  1/  1], step:[ 160/1019], time:0.450440s,  
  7. loss_g:10.24, loss_d:0.38, loss_g_a: 0.54, loss_g_b: 0.65, loss_c_a: 3.35, loss_c_b: 2.68, loss_idt_a: 1.68, loss_idt_b: 1.33  
  8. Epoch:[  1/  1], step:[ 240/1019], time:0.450548s,  
  9. loss_g:10.80, loss_d:0.36, loss_g_a: 0.32, loss_g_b: 0.43, loss_c_a: 3.88, loss_c_b: 3.42, loss_idt_a: 1.33, loss_idt_b: 1.42  
  10. Epoch:[  1/  1], step:[ 320/1019], time:0.501738s,  
  11. loss_g:7.80, loss_d:0.30, loss_g_a: 0.30, loss_g_b: 0.41, loss_c_a: 1.57, loss_c_b: 3.54, loss_idt_a: 0.48, loss_idt_b: 1.50  
  12. Epoch:[  1/  1], step:[ 400/1019], time:0.459365s,  
  13. loss_g:6.60, loss_d:0.70, loss_g_a: 0.29, loss_g_b: 0.28, loss_c_a: 1.60, loss_c_b: 2.48, loss_idt_a: 0.88, loss_idt_b: 1.06  
  14. Epoch:[  1/  1], step:[ 480/1019], time:0.448666s,  
  15. loss_g:4.93, loss_d:0.62, loss_g_a: 0.16, loss_g_b: 0.59, loss_c_a: 1.27, loss_c_b: 1.73, loss_idt_a: 0.46, loss_idt_b: 0.72  
  16. Epoch:[  1/  1], step:[ 560/1019], time:0.469116s,  
  17. loss_g:5.46, loss_d:0.43, loss_g_a: 0.37, loss_g_b: 0.59, loss_c_a: 1.59, loss_c_b: 1.38, loss_idt_a: 0.94, loss_idt_b: 0.60  
  18. Epoch:[  1/  1], step:[ 640/1019], time:0.453548s,  
  19. loss_g:5.32, loss_d:0.34, loss_g_a: 0.44, loss_g_b: 0.40, loss_c_a: 1.52, loss_c_b: 1.57, loss_idt_a: 0.62, loss_idt_b: 0.77  
  20. Epoch:[  1/  1], step:[ 720/1019], time:0.460250s,  
  21. loss_g:5.53, loss_d:0.46, loss_g_a: 0.49, loss_g_b: 0.20, loss_c_a: 1.64, loss_c_b: 1.90, loss_idt_a: 0.64, loss_idt_b: 0.66  
  22. Epoch:[  1/  1], step:[ 800/1019], time:0.460121s,  
  23. loss_g:4.58, loss_d:0.29, loss_g_a: 0.37, loss_g_b: 0.69, loss_c_a: 1.34, loss_c_b: 1.16, loss_idt_a: 0.60, loss_idt_b: 0.42  
  24. Epoch:[  1/  1], step:[ 880/1019], time:0.453644s,  
  25. loss_g:6.31, loss_d:0.44, loss_g_a: 0.37, loss_g_b: 0.18, loss_c_a: 2.37, loss_c_b: 1.63, loss_idt_a: 1.09, loss_idt_b: 0.67  
  26. Epoch:[  1/  1], step:[ 960/1019], time:0.449156s,  
  27. loss_g:3.81, loss_d:0.26, loss_g_a: 0.42, loss_g_b: 0.35, loss_c_a: 0.73, loss_c_b: 1.38, loss_idt_a: 0.33, loss_idt_b: 0.60  
  28. Epoch:[  1/  1], epoch time:609.15s, per step time:0.60, mean_g_loss:7.00, mean_d_loss:0.45  
  29. End of training!  
  30. CPU times: user 19min 48s, sys: 5min 15s, total: 25min 3s  
  31. Wall time: 10min 10s  

模型推理


        下面我们通过加载生成器网络模型参数文件来对原图进行风格迁移,结果中第一行为原图,第二行为对应生成的结果图。

        代码如下:

  1. %%time  
  2. import os  
  3. from PIL import Image  
  4. import mindspore.dataset as ds  
  5. import mindspore.dataset.vision as vision  
  6. from mindspore import load_checkpoint, load_param_into_net  
  7.   
  8. # 加载权重文件  
  9. def load_ckpt(net, ckpt_dir):  
  10.     param_GA = load_checkpoint(ckpt_dir)  
  11.     load_param_into_net(net, param_GA)  
  12.   
  13. g_a_ckpt = './CycleGAN_apple2orange/ckpt/g_a.ckpt'  
  14. g_b_ckpt = './CycleGAN_apple2orange/ckpt/g_b.ckpt'  
  15.   
  16. load_ckpt(net_rg_a, g_a_ckpt)  
  17. load_ckpt(net_rg_b, g_b_ckpt)  
  18.   
  19. # 图片推理  
  20. fig = plt.figure(figsize=(112.5), dpi=100)  
  21. def eval_data(dir_path, net, a):  
  22.   
  23.     def read_img():  
  24.         for dir in os.listdir(dir_path):  
  25.             path = os.path.join(dir_path, dir)  
  26.             img = Image.open(path).convert('RGB')  
  27.             yield img, dir  
  28.   
  29.     dataset = ds.GeneratorDataset(read_img, column_names=["image""image_name"])  
  30.     trans = [vision.Resize((256256)), vision.Normalize(mean=[0.5 * 255] * 3, std=[0.5 * 255] * 3), vision.HWC2CHW()]  
  31.     dataset = dataset.map(operations=trans, input_columns=["image"])  
  32.     dataset = dataset.batch(1)  
  33.     for i, data in enumerate(dataset.create_dict_iterator()):  
  34.         img = data["image"]  
  35.         fake = net(img)  
  36.         fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((120))  
  37.         img = (img[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((120))  
  38.   
  39.         fig.add_subplot(28, i+1+a)  
  40.         plt.axis("off")  
  41.         plt.imshow(img.asnumpy())  
  42.   
  43.         fig.add_subplot(28, i+9+a)  
  44.         plt.axis("off")  
  45.         plt.imshow(fake.asnumpy())  
  46.   
  47. eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0)  
  48. eval_data('./CycleGAN_apple2orange/predict/orange', net_rg_b, 4)  
  49. plt.show()

        分析:首先,设置了代码运行的计时环境,并导入了所需的库和模块。

        定义了一个用于加载模型检查点权重的函数 load_ckpt 。

        指定了要加载的生成器 net_rg_a 和 net_rg_b 的检查点文件路径。

        调用 load_ckpt 函数加载相应的权重到生成器模型中。

        创建了一个用于绘制推理结果的图形,并定义了一个用于进行图片推理的函数 eval_data 。

        在 eval_data 函数内部定义了一个用于读取指定目录下图片的生成器函数 read_img 。

        通过读取图片生成数据集,并对图片进行预处理操作(如调整大小、归一化和格式转换),然后将数据集进行批处理。

        对数据集进行迭代,通过生成器模型进行推理得到生成的假图片,并对原始图片和生成的假图片进行数据处理。

        将原始图片和生成的假图片添加到图形的子图中进行展示。

        分别对苹果和橙子的图片目录进行推理,并展示结果。

        总的来说,这段代码主要实现了加载生成器模型的权重,并对指定目录下的苹果和橙子图片进行推理,将原始图片和生成的图片进行展示。

        运行结果:

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/正经夜光杯/article/detail/844412
推荐阅读
相关标签
  

闽ICP备14008679号