当前位置:   article > 正文

【Pytorch神经网络实战案例】25 (带数据增强)基于迁移学习识别多种鸟类(CUB-200数据集)_randaugmen 是什么

randaugmen 是什么

1 数据增强

在目前分类效果最好的EficientNet系列模型中,EfficientNet-B7版本的模型就是使用随机数据增强方法训练而成的。

RandAugment方法也是目前主流的数据增强方法,用RandAugment方法进行训练,会使模型的精度得到提升。

2 RandAugment

2.1 RandAugment方法简介

RandAugment方法是一种新的数据增强方法,它比自动数据增强(AutOAugment)方法更简单、更好用。它可以在原有的训练框架中,直接对AutoAugment方法进行替换。

2.1.1 Tip

AuoAugment方法包含30多个参数,可以对图片数据进行各种变换(参见arXiv网站上编号为1805.09501的论文)。

2.2 RandAugment方法的构成

RandAugment方法是在AutoAugment方法的基础之上,将30多个参数进行策略级的优化管理,使这30多个参数被简化成两个参数:图片的N次变换和每次变换的强度M。其中每次变换的强度M,取值为0~10(只取整数),表示使原有图片增强失真的大小。

RandAugment方法以结果为导向,使数据增强过程更加面向用。在减少AutoAugment的运算消耗的同时,又使增强的效果变得可控。详细内容可以参考相关论文(参见arXⅳ网站上编号为1909.13719的论文)。

2.2 代码获取

  1. https://github.com/heartInsert/randaugment
  2. # 只有一个代码文件Rand_Augment,py,将其下载后,直接引入代码即可使用。

3 本节案例(带有数据增强的识别)

3.1 案例简介

使用迁移学习对预训练模型进行微调的基础上实现数据增强,让其学习鸟类数据集,实现对多种鸟类进行识别。

3.2 代码实现:load_data函数加载图片名称与标签的加载----Transfer_bird2_Augmentation.py(第1部分)

  1. import glob
  2. import numpy as np
  3. from PIL import Image
  4. import matplotlib.pyplot as plt #plt 用于显示图片
  5. import torch
  6. import torch.nn as nn
  7. import torch.optim as optim
  8. from torch.optim import lr_scheduler
  9. from torch.utils.data import Dataset,DataLoader
  10. import torchvision
  11. import torchvision.models as model
  12. from torchvision.transforms import ToPILImage
  13. import torchvision.transforms as transforms
  14. import os
  15. os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
  16. # 1.1 实现load_data函数加载图片名称与标签的加载,并使用torch.utils.data接口将其封装成程序可用的数据集类OwnDataset。
  17. def load_dir(directory,labstart=0): # 获取所有directory中的所有图与标签
  18. # 返回path指定的文件夹所包含的文件或文件名的名称列表
  19. strlabels = os.listdir(directory)
  20. # 对标签进行排序,以便训练和验证按照相同的顺序进行:在不同的操作系统中,加载文件夹的顺序可能不同。目录不同的情况会导致在不同的操作系统中,模型的标签出现串位的现象。所以需要对文件夹进行排序,保证其顺序的一致性。
  21. strlabels.sort()
  22. # 创建文件标签列表
  23. file_labels = []
  24. for i,label in enumerate(strlabels):
  25. print(label)
  26. jpg_names = glob.glob(os.path.join(directory,label,"*.jpg"))
  27. print(jpg_names)
  28. # 加入列表
  29. file_labels.extend(zip(jpg_names, [i + labstart] * len(jpg_names)))
  30. return file_labels,strlabels
  31. def load_data(dataset_path): # 定义函数load_data函数完成对数据集中图片文件名称和标签的加载。
  32. # 该函数可以实现两层文件夹的嵌套结构。其中,外层结构使用load_data函数进行遍历,内层结构使用load_dir函进行遍历。
  33. sub_dir = sorted(os.listdir(dataset_path)) # 跳过子文件夹:在不同的操作系统中,加载文件夹的顺序可能不同。目录不同的情况会导致在不同的操作系统中,模型的标签出现串位的现象。所以需要对文件夹进行排序,保证其顺序的一致性。
  34. start = 1 # 第0类是none
  35. tfile_lables,tstrlabels = [],['none'] # 在制作标签时,人为地在前面添加了一个序号为0的none类。这是一个训练图文类模型的技巧,为了区分模型输出值是0和预测值是0这两种情况。
  36. for i in sub_dir:
  37. directory = os.path.join(dataset_path,i)
  38. if os.path.isdir(directory) == False: # 只处理文件夹中的数据
  39. print(directory)
  40. continue
  41. file_labels,strlables = load_dir(directory,labstart=start)
  42. tfile_lables.extend(file_labels)
  43. tstrlabels.extend(strlables)
  44. start = len(strlables)
  45. # 将数据路径与标签解压缩,把数据路径和标签解压缩出来
  46. filenames,labels = zip(*tfile_lables)
  47. return filenames, labels, tstrlabels

3.3 代码实现:自定义数据集类OwnDataset----Transfer_bird2_Augmentation.py(第2部分)

  1. # 1.2 实现自定义数据集OwnDataset
  2. def default_loader(path) : # 定义函数加载图片
  3. return Image.open(path).convert('RGB')
  4. class OwnDataset(Dataset): # 复用性较强,可根据自己的数据集略加修改使用
  5. # 在PyTorch中,提供了一个torch.utis.data接口,可以用来对数据集进行封装。在实现时,只需要继承torch.utis.data.Dataset类,并重载其__gettem__方法。
  6. # 在使用时,框架会向__gettem__方法传入索引index,在__gettem__方法内部根据指定index加载数据,并返回。
  7. def __init__(self,img_dir,labels,indexlist=None,transform=transforms.ToTensor(),loader=default_loader,cache=True): # 初始化
  8. self.labels = labels # 存放标签
  9. self.img_dir = img_dir # 样本图片文件名
  10. self.transform = transform # 预处理方法
  11. self.loader = loader # 加载方法
  12. self.cache = cache # 缓存标志
  13. if indexlist is None: # 要加载的数据序列
  14. self.indexlist = list(range(len(self.img_dir)))
  15. else:
  16. self.indexlist = indexlist
  17. self.data = [None] * len(self.indexlist) # 存放样本图片
  18. def __getitem__(self, idx): # 加载指定索引数据
  19. if self.data[idx] is None: # 第一次加载
  20. data = self.loader(self.img_dir[self.indexlist[idx]])
  21. if self.transform:
  22. data = self.transform(data)
  23. else:
  24. data = self.data[idx]
  25. if self.cache: # 保存到缓存里
  26. self.data[idx] = data
  27. return data,self.labels[self.indexlist[idx]]
  28. def __len__(self): # 计算数据集长度
  29. return len(self.indexlist)

3.4 代码实战:测试数据集----Transfer_bird2_Augmentation.py(第3部分)【数据增强模块】

  1. # 1.3 测试数据集:在完成数据集的制作之后,编写代码对其进行测试。
  2. # 数据增强模块
  3. from Rand_Augment import Rand_Augment
  4. data_transform = { #定义数据的预处理方法
  5. 'train':transforms.Compose([
  6. Rand_Augment(), # 数据增强的方法带入 仅此一处修改
  7. transforms.RandomResizedCrop(224),
  8. transforms.RandomHorizontalFlip(),
  9. transforms.ToTensor(),
  10. transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
  11. ]),
  12. 'val':transforms.Compose([
  13. transforms.Resize(256),
  14. transforms.CenterCrop(224),
  15. transforms.ToTensor(),
  16. transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
  17. ]),
  18. }
  19. def Reduction_img(tensor,mean,std): #还原图片,实现了图片归一化的逆操作,显示数据集中的原始图片。
  20. dtype = tensor.dtype
  21. mean = torch.as_tensor(mean,dtype=dtype,device=tensor.device)
  22. std = torch.as_tensor(std,dtype=dtype,device=tensor.device)
  23. tensor.mul_(std[:,None,None]).add_(mean[:,None,None]) # 还原操作
  24. dataset_path = r'./data/cub200/' # 加载数据集路径
  25. filenames,labels,classes = load_data(dataset_path) # 调用load_data函数对数据集中图片文件名称和标签进行加载,其返回对象classes中包含全部的类名。
  26. # 打乱数据顺序
  27. # 110-115行对数据文件列表的序号进行乱序划分,分为测试数据集和训练数集两个索引列表。该索引列表会传入OwnDataset类做成指定的数据集。
  28. np.random.seed(0)
  29. label_shuffle_index = np.random.permutation(len(labels))
  30. label_train_num = (len(labels)//10) * 8 # 划分训练数据集和测试数据集
  31. train_list = label_shuffle_index[0:label_train_num]
  32. test_list = label_shuffle_index[label_train_num:] # 没带:
  33. train_dataset = OwnDataset(filenames,labels,train_list,data_transform['train'])# 实例化训练数据集
  34. val_dataset = OwnDataset(filenames,labels,test_list,data_transform['val']) # 实例化测试数据集
  35. # 实例化批次数据集:OwnDataset类所定义的数据集,其使用方法与PyTorch中的内置数据集的使用方法完全一致,配合DataLoader接口即可生成可以进行训练或测试的批次数据。具体代码如下。
  36. train_loader = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True)
  37. val_loader = DataLoader(dataset=val_dataset,batch_size=32,shuffle=True)
  38. sample = iter(train_loader) # 获取一批次数据,进行测试
  39. images,labels = sample.next()
  40. print("样本形状",np.shape(images))
  41. print("标签个数",len(classes))
  42. mulimgs = torchvision.utils.make_grid(images[:10],nrow=10) # 拼接多张图片
  43. Reduction_img(mulimgs,[0.485,0.456,0.406],[0.229,0.224,0.225])
  44. _img = ToPILImage()(mulimgs) # 将张量转化为图片
  45. plt.axis('off')
  46. plt.imshow(_img) # 显示
  47. plt.show()
  48. print(','.join('%5s' % classes[labels[j]] for j in range(len(images[:10]))))

输出:

样本形状 torch.Size([32, 3, 224, 224])
标签个数 6

输出数据集中的10个图片

3.5 代码实战:获取并改造ResNet模型----Transfer_bird2_Augmentation.py(第4部分)

  1. # 1.4 获取并改造ResNet模型:获取ResNet模型,并加载预训练模型的权重。将其最后一层(输出层)去掉,换成一个全新的全连接层,该全连接层的输出节点数与本例分类数相同。
  2. # 指定设备
  3. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  4. print(device)
  5. # get_ResNet函数,获取预训练模型,可指定pretrained=True来实现自动下载预训练模型,也可指定loadfile来从本地路径加载预训练模型。
  6. def get_ResNet(classes,pretrained=True,loadfile=None):
  7. ResNet = model.resnet101(pretrained) # 自动下载官方的预训练模型
  8. if loadfile != None:
  9. ResNet.load_state_dict(torch.load(loadfile)) # 加载本地模型
  10. # 将所有的参数层进行冻结:设置模型仅最后一层可以进行训练,使模型只针对最后一层进行微调。
  11. for param in ResNet.parameters():
  12. param.requires_grad = False
  13. # 输出全连接层的信息
  14. print(ResNet.fc)
  15. x = ResNet.fc.in_features # 获取全连接层的输入
  16. ResNet.fc = nn.Linear(x,len(classes)) # 定义一个新的全连接层
  17. print(ResNet.fc) # 最后输出新的模型
  18. return ResNet
  19. ResNet = get_ResNet(classes) # 实例化模型
  20. ResNet.to(device=device)

3.6 代码实战:定义损失函数、训练函数及测试函数,对模型的最后一层进行微调----Transfer_bird2_Augmentation.py(第5部分)

  1. # 1.5 定义损失函数、训练函数及测试函数,对模型的最后一层进行微调。
  2. criterion = nn.CrossEntropyLoss()
  3. # 指定新加的全连接层的学习率
  4. optimizer = torch.optim.Adam([{'params':ResNet.fc.parameters()}],lr=0.01)
  5. def train(model,device,train_loader,epoch,optimizer): # 定义训练函数
  6. model.train()
  7. allloss = []
  8. for batch_idx,data in enumerate(train_loader):
  9. x,y = data
  10. x = x.to(device)
  11. y = y.to(device)
  12. optimizer.zero_grad()
  13. y_hat = model(x)
  14. loss = criterion(y_hat,y)
  15. loss.backward()
  16. allloss.append(loss.item())
  17. optimizer.step()
  18. print('Train Epoch:{}\t Loss:{:.6f}'.format(epoch,np.mean(allloss))) # 输出训练结果
  19. def test(model,device,val_loader): # 定义测试函数
  20. model.eval()
  21. test_loss = []
  22. correct = []
  23. with torch.no_grad(): # 使模型在运行时不进行梯度跟踪,可以减少模型运行时对内存的占用。
  24. for i,data in enumerate(val_loader):
  25. x, y = data
  26. x = x.to(device)
  27. y = y.to(device)
  28. y_hat = model(x)
  29. test_loss.append(criterion(y_hat,y).item()) # 收集损失函数
  30. pred = y_hat.max(1,keepdim=True)[1] # 获取预测结果
  31. correct.append(pred.eq(y.view_as(pred)).sum().item()/pred.shape[0]) # 收集精确度
  32. print('\nTest:Average loss:{:,.4f},Accuracy:({:,.0f}%)\n'.format(np.mean(test_loss),np.mean(correct)*100)) # 输出测试结果
  33. # 迁移学习的两个步骤如下
  34. if __name__ == '__main__':
  35. # 迁移学习步骤①:固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛。
  36. firstmodepth = './data/cub200/firstmodepth_1.pth' # 定义模型文件的地址
  37. if os.path.exists(firstmodepth) == False:
  38. print("—————————固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛—————————")
  39. for epoch in range(1,2): # 迭代两次
  40. train(ResNet,device,train_loader,epoch,optimizer)
  41. test(ResNet,device,val_loader)
  42. # 保存模型
  43. torch.save(ResNet.state_dict(),firstmodepth)

3.7 代码实战:使用退化学习率对模型进行全局微调----Transfer_bird2_Augmentation.py(第6部分)

  1. # 1.6 使用退化学习率对模型进行全局微调
  2. #迁移学习步骤②:使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节,即将模型的每层权重都设为可训练,并定义带有退化学习率的优化器。(1.6部分)
  3. secondmodepth = './data/cub200/firstmodepth_2.pth'
  4. optimizer2 = optim.SGD(ResNet.parameters(),lr=0.001,momentum=0.9) # 第198行代码定义带有退化学习率的SGD优化器。该优化器常用来对模型进行手动微调。有实验表明,使用经过手动调节的SGD优化器,在训练模型的后期效果优于Adam优化器。
  5. exp_lr_scheduler = lr_scheduler.StepLR(optimizer2,step_size=2,gamma=0.9) # 由于退化学习率会在训练过程中不断地变小,为了防止学习率过小,最终无法进行权重需要对其设置最小值。当学习率低于该值时,停止对退化学习率的操作。
  6. for param in ResNet.parameters(): # 所有参数设计为可训练
  7. param.requires_grad = True
  8. if os.path.exists(secondmodepth):
  9. ResNet.load_state_dict(torch.load(secondmodepth)) # 加载本地模型
  10. else:
  11. ResNet.load_state_dict(torch.load(firstmodepth)) # 加载本地模型
  12. print("____使用较小的学习率,对全部模型进行训练,定义带有退化学习率的优化器______")
  13. for epoch in range(1,100):
  14. train(ResNet,device,train_loader,epoch,optimizer2)
  15. if optimizer2.state_dict()['param_groups'][0]['lr'] > 0.00001:
  16. exp_lr_scheduler.step()
  17. print("___lr:",optimizer2.state_dict()['param_groups'][0]['lr'])
  18. test(ResNet,device,val_loader)
  19. # 保存模型
  20. torch.save(ResNet.state_dict(),secondmodepth)

4 代码总览Transfer_bird2_Augmentation.py

  1. import glob
  2. import numpy as np
  3. from PIL import Image
  4. import matplotlib.pyplot as plt #plt 用于显示图片
  5. import torch
  6. import torch.nn as nn
  7. import torch.optim as optim
  8. from torch.optim import lr_scheduler
  9. from torch.utils.data import Dataset,DataLoader
  10. import torchvision
  11. import torchvision.models as model
  12. from torchvision.transforms import ToPILImage
  13. import torchvision.transforms as transforms
  14. import os
  15. os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
  16. # 1.1 实现load_data函数加载图片名称与标签的加载,并使用torch.utils.data接口将其封装成程序可用的数据集类OwnDataset。
  17. def load_dir(directory,labstart=0): # 获取所有directory中的所有图与标签
  18. # 返回path指定的文件夹所包含的文件或文件名的名称列表
  19. strlabels = os.listdir(directory)
  20. # 对标签进行排序,以便训练和验证按照相同的顺序进行:在不同的操作系统中,加载文件夹的顺序可能不同。目录不同的情况会导致在不同的操作系统中,模型的标签出现串位的现象。所以需要对文件夹进行排序,保证其顺序的一致性。
  21. strlabels.sort()
  22. # 创建文件标签列表
  23. file_labels = []
  24. for i,label in enumerate(strlabels):
  25. print(label)
  26. jpg_names = glob.glob(os.path.join(directory,label,"*.jpg"))
  27. print(jpg_names)
  28. # 加入列表
  29. file_labels.extend(zip(jpg_names, [i + labstart] * len(jpg_names)))
  30. return file_labels,strlabels
  31. def load_data(dataset_path): # 定义函数load_data函数完成对数据集中图片文件名称和标签的加载。
  32. # 该函数可以实现两层文件夹的嵌套结构。其中,外层结构使用load_data函数进行遍历,内层结构使用load_dir函进行遍历。
  33. sub_dir = sorted(os.listdir(dataset_path)) # 跳过子文件夹:在不同的操作系统中,加载文件夹的顺序可能不同。目录不同的情况会导致在不同的操作系统中,模型的标签出现串位的现象。所以需要对文件夹进行排序,保证其顺序的一致性。
  34. start = 1 # 第0类是none
  35. tfile_lables,tstrlabels = [],['none'] # 在制作标签时,人为地在前面添加了一个序号为0的none类。这是一个训练图文类模型的技巧,为了区分模型输出值是0和预测值是0这两种情况。
  36. for i in sub_dir:
  37. directory = os.path.join(dataset_path,i)
  38. if os.path.isdir(directory) == False: # 只处理文件夹中的数据
  39. print(directory)
  40. continue
  41. file_labels,strlables = load_dir(directory,labstart=start)
  42. tfile_lables.extend(file_labels)
  43. tstrlabels.extend(strlables)
  44. start = len(strlables)
  45. # 将数据路径与标签解压缩,把数据路径和标签解压缩出来
  46. filenames,labels = zip(*tfile_lables)
  47. return filenames, labels, tstrlabels
  48. # 1.2 实现自定义数据集OwnDataset
  49. def default_loader(path) : # 定义函数加载图片
  50. return Image.open(path).convert('RGB')
  51. class OwnDataset(Dataset): # 复用性较强,可根据自己的数据集略加修改使用
  52. # 在PyTorch中,提供了一个torch.utis.data接口,可以用来对数据集进行封装。在实现时,只需要继承torch.utis.data.Dataset类,并重载其__gettem__方法。
  53. # 在使用时,框架会向__gettem__方法传入索引index,在__gettem__方法内部根据指定index加载数据,并返回。
  54. def __init__(self,img_dir,labels,indexlist=None,transform=transforms.ToTensor(),loader=default_loader,cache=True): # 初始化
  55. self.labels = labels # 存放标签
  56. self.img_dir = img_dir # 样本图片文件名
  57. self.transform = transform # 预处理方法
  58. self.loader = loader # 加载方法
  59. self.cache = cache # 缓存标志
  60. if indexlist is None: # 要加载的数据序列
  61. self.indexlist = list(range(len(self.img_dir)))
  62. else:
  63. self.indexlist = indexlist
  64. self.data = [None] * len(self.indexlist) # 存放样本图片
  65. def __getitem__(self, idx): # 加载指定索引数据
  66. if self.data[idx] is None: # 第一次加载
  67. data = self.loader(self.img_dir[self.indexlist[idx]])
  68. if self.transform:
  69. data = self.transform(data)
  70. else:
  71. data = self.data[idx]
  72. if self.cache: # 保存到缓存里
  73. self.data[idx] = data
  74. return data,self.labels[self.indexlist[idx]]
  75. def __len__(self): # 计算数据集长度
  76. return len(self.indexlist)
  77. # 1.3 测试数据集:在完成数据集的制作之后,编写代码对其进行测试。
  78. # 数据增强模块
  79. from Rand_Augment import Rand_Augment
  80. data_transform = { #定义数据的预处理方法
  81. 'train':transforms.Compose([
  82. Rand_Augment(), # 数据增强的方法带入 仅此一处修改
  83. transforms.RandomResizedCrop(224),
  84. transforms.RandomHorizontalFlip(),
  85. transforms.ToTensor(),
  86. transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
  87. ]),
  88. 'val':transforms.Compose([
  89. transforms.Resize(256),
  90. transforms.CenterCrop(224),
  91. transforms.ToTensor(),
  92. transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
  93. ]),
  94. }
  95. def Reduction_img(tensor,mean,std): #还原图片,实现了图片归一化的逆操作,显示数据集中的原始图片。
  96. dtype = tensor.dtype
  97. mean = torch.as_tensor(mean,dtype=dtype,device=tensor.device)
  98. std = torch.as_tensor(std,dtype=dtype,device=tensor.device)
  99. tensor.mul_(std[:,None,None]).add_(mean[:,None,None]) # 还原操作
  100. dataset_path = r'./data/cub200/' # 加载数据集路径
  101. filenames,labels,classes = load_data(dataset_path) # 调用load_data函数对数据集中图片文件名称和标签进行加载,其返回对象classes中包含全部的类名。
  102. # 打乱数据顺序
  103. # 110-115行对数据文件列表的序号进行乱序划分,分为测试数据集和训练数集两个索引列表。该索引列表会传入OwnDataset类做成指定的数据集。
  104. np.random.seed(0)
  105. label_shuffle_index = np.random.permutation(len(labels))
  106. label_train_num = (len(labels)//10) * 8 # 划分训练数据集和测试数据集
  107. train_list = label_shuffle_index[0:label_train_num]
  108. test_list = label_shuffle_index[label_train_num:] # 没带:
  109. train_dataset = OwnDataset(filenames,labels,train_list,data_transform['train'])# 实例化训练数据集
  110. val_dataset = OwnDataset(filenames,labels,test_list,data_transform['val']) # 实例化测试数据集
  111. # 实例化批次数据集:OwnDataset类所定义的数据集,其使用方法与PyTorch中的内置数据集的使用方法完全一致,配合DataLoader接口即可生成可以进行训练或测试的批次数据。具体代码如下。
  112. train_loader = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True)
  113. val_loader = DataLoader(dataset=val_dataset,batch_size=32,shuffle=True)
  114. sample = iter(train_loader) # 获取一批次数据,进行测试
  115. images,labels = sample.next()
  116. print("样本形状",np.shape(images))
  117. print("标签个数",len(classes))
  118. mulimgs = torchvision.utils.make_grid(images[:10],nrow=10) # 拼接多张图片
  119. Reduction_img(mulimgs,[0.485,0.456,0.406],[0.229,0.224,0.225])
  120. _img = ToPILImage()(mulimgs) # 将张量转化为图片
  121. plt.axis('off')
  122. plt.imshow(_img) # 显示
  123. plt.show()
  124. print(','.join('%5s' % classes[labels[j]] for j in range(len(images[:10]))))
  125. # 1.4 获取并改造ResNet模型:获取ResNet模型,并加载预训练模型的权重。将其最后一层(输出层)去掉,换成一个全新的全连接层,该全连接层的输出节点数与本例分类数相同。
  126. # 指定设备
  127. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  128. print(device)
  129. # get_ResNet函数,获取预训练模型,可指定pretrained=True来实现自动下载预训练模型,也可指定loadfile来从本地路径加载预训练模型。
  130. def get_ResNet(classes,pretrained=True,loadfile=None):
  131. ResNet = model.resnet101(pretrained) # 自动下载官方的预训练模型
  132. if loadfile != None:
  133. ResNet.load_state_dict(torch.load(loadfile)) # 加载本地模型
  134. # 将所有的参数层进行冻结:设置模型仅最后一层可以进行训练,使模型只针对最后一层进行微调。
  135. for param in ResNet.parameters():
  136. param.requires_grad = False
  137. # 输出全连接层的信息
  138. print(ResNet.fc)
  139. x = ResNet.fc.in_features # 获取全连接层的输入
  140. ResNet.fc = nn.Linear(x,len(classes)) # 定义一个新的全连接层
  141. print(ResNet.fc) # 最后输出新的模型
  142. return ResNet
  143. ResNet = get_ResNet(classes) # 实例化模型
  144. ResNet.to(device=device)
  145. # 1.5 定义损失函数、训练函数及测试函数,对模型的最后一层进行微调。
  146. criterion = nn.CrossEntropyLoss()
  147. # 指定新加的全连接层的学习率
  148. optimizer = torch.optim.Adam([{'params':ResNet.fc.parameters()}],lr=0.01)
  149. def train(model,device,train_loader,epoch,optimizer): # 定义训练函数
  150. model.train()
  151. allloss = []
  152. for batch_idx,data in enumerate(train_loader):
  153. x,y = data
  154. x = x.to(device)
  155. y = y.to(device)
  156. optimizer.zero_grad()
  157. y_hat = model(x)
  158. loss = criterion(y_hat,y)
  159. loss.backward()
  160. allloss.append(loss.item())
  161. optimizer.step()
  162. print('Train Epoch:{}\t Loss:{:.6f}'.format(epoch,np.mean(allloss))) # 输出训练结果
  163. def test(model,device,val_loader): # 定义测试函数
  164. model.eval()
  165. test_loss = []
  166. correct = []
  167. with torch.no_grad(): # 使模型在运行时不进行梯度跟踪,可以减少模型运行时对内存的占用。
  168. for i,data in enumerate(val_loader):
  169. x, y = data
  170. x = x.to(device)
  171. y = y.to(device)
  172. y_hat = model(x)
  173. test_loss.append(criterion(y_hat,y).item()) # 收集损失函数
  174. pred = y_hat.max(1,keepdim=True)[1] # 获取预测结果
  175. correct.append(pred.eq(y.view_as(pred)).sum().item()/pred.shape[0]) # 收集精确度
  176. print('\nTest:Average loss:{:,.4f},Accuracy:({:,.0f}%)\n'.format(np.mean(test_loss),np.mean(correct)*100)) # 输出测试结果
  177. # 迁移学习的两个步骤如下
  178. if __name__ == '__main__':
  179. # 迁移学习步骤①:固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛。
  180. firstmodepth = './data/cub200/firstmodepth_1.pth' # 定义模型文件的地址
  181. if os.path.exists(firstmodepth) == False:
  182. print("—————————固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛—————————")
  183. for epoch in range(1,2): # 迭代两次
  184. train(ResNet,device,train_loader,epoch,optimizer)
  185. test(ResNet,device,val_loader)
  186. # 保存模型
  187. torch.save(ResNet.state_dict(),firstmodepth)
  188. # 1.6 使用退化学习率对模型进行全局微调
  189. #迁移学习步骤②:使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节,即将模型的每层权重都设为可训练,并定义带有退化学习率的优化器。(1.6部分)
  190. secondmodepth = './data/cub200/firstmodepth_2.pth'
  191. optimizer2 = optim.SGD(ResNet.parameters(),lr=0.001,momentum=0.9) # 第198行代码定义带有退化学习率的SGD优化器。该优化器常用来对模型进行手动微调。有实验表明,使用经过手动调节的SGD优化器,在训练模型的后期效果优于Adam优化器。
  192. exp_lr_scheduler = lr_scheduler.StepLR(optimizer2,step_size=2,gamma=0.9) # 由于退化学习率会在训练过程中不断地变小,为了防止学习率过小,最终无法进行权重需要对其设置最小值。当学习率低于该值时,停止对退化学习率的操作。
  193. for param in ResNet.parameters(): # 所有参数设计为可训练
  194. param.requires_grad = True
  195. if os.path.exists(secondmodepth):
  196. ResNet.load_state_dict(torch.load(secondmodepth)) # 加载本地模型
  197. else:
  198. ResNet.load_state_dict(torch.load(firstmodepth)) # 加载本地模型
  199. print("____使用较小的学习率,对全部模型进行训练,定义带有退化学习率的优化器______")
  200. for epoch in range(1,100):
  201. train(ResNet,device,train_loader,epoch,optimizer2)
  202. if optimizer2.state_dict()['param_groups'][0]['lr'] > 0.00001:
  203. exp_lr_scheduler.step()
  204. print("___lr:",optimizer2.state_dict()['param_groups'][0]['lr'])
  205. test(ResNet,device,val_loader)
  206. # 保存模型
  207. torch.save(ResNet.state_dict(),secondmodepth)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/248977
推荐阅读
相关标签
  

闽ICP备14008679号