
Getting Started with PyTorch, Part 3: Your Own Face Data + Transfer Learning (resnet18)

0. Preface

The earlier posts used the MNIST dataset with a fairly simple network. So how do we apply a more complex, classic network to our own data?

1. Dataset

The goal is face recognition, which can be treated as a multi-class classification problem. The dataset for this experiment is FERET: 200 people and 1400 images of size 3×80×80, so it is fairly small.


It is split into train and val directories, each containing 200 subdirectories (one per person).

The dataset can be downloaded here:

https://download.csdn.net/download/sinat_37787331/10383836

Note: the subdirectory names and their number must be identical between the training and validation directories; a subdirectory may be empty, but it has to exist.
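Because ImageFolder derives class indices from the sorted subdirectory names, a mismatch between the two splits silently shifts the labels. Here is a minimal sketch (my addition, reusing the paths from this post) that creates any class folder present under train/ but missing under val/:

import os

train_dir = '/home/syj/Documents/datas/Feret/train'
val_dir = '/home/syj/Documents/datas/Feret/val'

# create every class folder that exists under train/ but not under val/
for cls in sorted(os.listdir(train_dir)):
    target = os.path.join(val_dir, cls)
    if not os.path.isdir(target):
        os.makedirs(target)
        print("Created empty class folder: " + target)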

Attached below: scripts for batch-deleting files and batch-converting file extensions.

#!/usr/bin/python
# -*- coding: utf-8 -*-
import os

def del_files(path):
    # recursively delete every .png file under path
    for root, dirs, files in os.walk(path):
        for name in files:
            if name.endswith(".png"):
                os.remove(os.path.join(root, name))
                print("Delete File: " + os.path.join(root, name))

# test
if __name__ == "__main__":
    path = '/home/syj/Documents/datas/2'
    del_files(path)

'''
Batch-change file extensions:

#!/usr/bin/python
# -*- coding: utf-8 -*-
import os

def model_extentsion(path, before_ext, ext):
    # recursively rename files from .before_ext to .ext
    for name in os.listdir(path):
        full_path = os.path.join(path, name)
        if os.path.isfile(full_path):
            split_path = os.path.splitext(full_path)
            pwd_name = split_path[0]
            pwd_ext = split_path[1]
            before_ext1 = "." + before_ext
            if pwd_ext == before_ext1:
                ext1 = "." + ext
                pwd_name += ext1
                re_name = os.path.join(path, pwd_name)
                os.renames(full_path, re_name)
        else:
            model_extentsion(full_path, before_ext, ext)

model_extentsion("/home/syj/Documents/datas/Feret/train", 'tif', "png")
'''

2. Data Loading

This time we are loading our own data; broadly there are two approaches.

Approach 1: an image folder plus a txt file listing image paths and labels.

For reference, see http://www.bubuko.com/infodetail-2304938.html

Approach 2: the training and validation sets are kept separate, and all files of one class sit in the same subdirectory. This post uses this approach.

# face data
train_data = torchvision.datasets.ImageFolder('/home/syj/Documents/datas/Feret/train',
                                              transform=transforms.Compose([
                                                  transforms.Resize(28),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                              ]))
# mini-batch training: 50 samples, 1 channel, 28x28 -> (50, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.ImageFolder('/home/syj/Documents/datas/Feret/val',
                                             transform=transforms.Compose([
                                                 transforms.Resize(28),
                                                 transforms.ToTensor(),
                                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                             ]))
test_loader = Data.DataLoader(dataset=test_data, batch_size=20, shuffle=True)

'''
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),  # 224x224 is the input size resnet18 expects
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # normalization
    ]),
}
'''

This is implemented mainly through the torchvision.datasets.ImageFolder class, which is very convenient.
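As a quick sanity check (my addition, not part of the original post), you can inspect the class-to-index mapping that ImageFolder builds from the subdirectory names, and the shape of one batch:

# ImageFolder maps each subdirectory name to a label index
print(train_data.class_to_idx)   # e.g. {'001': 0, '002': 1, ...}, depending on folder names
print(len(train_data))           # total number of training images found

# one batch from the loader: images (batch_size, 3, H, W), integer labels
images, labels = next(iter(train_loader))
print(images.size(), labels.size())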

For details on normalization and the other transforms, see https://blog.csdn.net/Hungryof/article/details/76649006


3. Transfer Learning: Load a resnet18 Model and Fine-Tune It

Torchvision provides many classic models, such as AlexNet, VGG, and ResNet, along with pretrained weights, which makes them convenient for transfer learning.

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 200)

These three lines build the entire network; the pretrained resnet18 weights are downloaded automatically. The only modification is the final fc layer, which goes from 1000 outputs (the ImageNet classes) to 200.
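A common variant worth knowing about (this is the standard feature-extraction recipe, sketched by me, not the code this post actually uses): freeze the pretrained backbone and train only the new fc layer, which is often faster and less prone to overfitting on a dataset this small:

model_conv = models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False           # freeze every pretrained weight

num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 200)  # the fresh layer requires grad by default

# hand only the final layer's parameters to the optimizer
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.01, momentum=0.9)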


4. Saving and Loading the Model

There are two options: save only the parameters (the state_dict), or save the whole model. The latter is simpler but takes more disk space; I use the latter.

model_ft = torch.load('/home/syj/Documents/model/resnet18_0.003.pkl')
# torch.save(model_ft, '/home/syj/Documents/model/resnet18_0.003.pkl')
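For completeness, here is a sketch of the parameters-only route (the file name is illustrative): save the state_dict, then rebuild the same architecture before restoring the weights:

# save only the parameters (smaller file, more portable)
torch.save(model_ft.state_dict(), '/home/syj/Documents/model/resnet18_params.pkl')

# load: reconstruct the architecture first, then restore the weights
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 200)
model.load_state_dict(torch.load('/home/syj/Documents/model/resnet18_params.pkl'))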

5. Results

I ran 9 epochs; accuracy on the 200 classes came out around 72%, close to the theoretical level, and training took about 4 minutes on a GT 940M.

See also http://www.cnblogs.com/denny402/p/7520063.html


6. Full Code

from __future__ import print_function, division

import copy
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torchvision
from torchvision import datasets, models, transforms
import matplotlib as mpl
import matplotlib.pyplot as plt


def train_model(model, criterion, optimizer, scheduler, num_epochs=1):
    since = time.time()
    # deep-copy the weights so later training steps cannot mutate the snapshot
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                model.train(True)   # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for data in dataloders[phase]:
                # get the inputs
                inputs, labels = data

                # wrap them in Variable
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                outputs = model(inputs)
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # statistics
                running_loss += loss.data[0]
                running_corrects += torch.sum(preds == labels.data)

                # per-batch curves for plotting
                if phase == 'train':
                    train_loss.append(loss.data[0] / 15)
                    train_acc.append(torch.sum(preds == labels.data) / 15)
                else:
                    test_loss.append(loss.data[0] / 15)
                    test_acc.append(torch.sum(preds == labels.data) / 15)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print('{} Loss {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


if __name__ == '__main__':
    # data_transform; note that Normalize() takes a Tensor while
    # RandomResizedCrop() and RandomHorizontalFlip() take a PIL Image
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    # your image data directory
    data_dir = '/home/syj/Documents/datas/Feret'
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x])
                      for x in ['train', 'val']}

    # wrap your data and labels into DataLoaders
    dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                 batch_size=10,
                                                 shuffle=True,
                                                 num_workers=10)
                  for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

    # use gpu or not
    use_gpu = torch.cuda.is_available()

    # get the model and replace the original fc layer with your own fc layer
    # model_ft = models.resnet18(pretrained=True)
    # num_ftrs = model_ft.fc.in_features
    # model_ft.fc = nn.Linear(num_ftrs, 200)
    model_ft = torch.load('/home/syj/Documents/model/resnet18_0.003.pkl')

    ## buffers for plotting
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    if use_gpu:
        model_ft = model_ft.cuda()

    # define loss function
    criterion = nn.CrossEntropyLoss()

    # Observe that all parameters are being optimized
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.01, momentum=0.9)

    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

    model_ft = train_model(model=model_ft,
                           criterion=criterion,
                           optimizer=optimizer_ft,
                           scheduler=exp_lr_scheduler,
                           num_epochs=2)
    # torch.save(model_ft, '/home/syj/Documents/model/resnet18_0.003.pkl')

    '''
    ## plot the curves
    plt.figure()
    plt.subplot(2, 2, 1)
    plt.plot(train_loss, lw=1.5, label='train_loss')
    plt.subplot(2, 2, 2)
    plt.plot(train_acc, lw=1.5, label='train_acc')
    plt.subplot(2, 2, 3)
    plt.plot(test_loss, lw=1.5, label='loss')
    plt.subplot(2, 2, 4)
    plt.plot(test_acc, lw=1.5, label='acc')
    plt.savefig("resnet18_0.01-10.jpg")
    plt.show()
    print(dataset_sizes)
    '''

    '''
    https://blog.csdn.net/u014380165/article/details/78525273
    ----------
    train Loss: 0.1916 Acc: 0.8083
    val Loss: 0.0262 Acc: 0.9778
    Epoch 24/24
    ----------
    train Loss: 0.2031 Acc: 0.8250
    val Loss: 0.0269 Acc: 1.0000
    Training complete in 4m 19s
    Best val Acc: 1.000000
    '''

    ''' lr=0.003
    Epoch 9/9
    ----------
    train Loss: 0.1358 Acc: 0.6710
    val Loss: 0.1135 Acc: 0.6575
    Training complete in 9m 43s
    Best val Acc: 0.657500
    '''

    ''' lr=0.01 15
    Epoch 9/9
    ----------
    train Loss: 0.0415 Acc: 0.8530
    val Loss: 0.0802 Acc: 0.7225
    Training complete in 10m 6s
    Best val Acc: 0.722500
    '''

    ''' 0.01 10
    Epoch 38/39
    ----------
    train Loss: 0.0509 Acc: 0.8640
    val Loss: 0.1262 Acc: 0.7325
    Epoch 39/39
    ----------
    train Loss: 0.0508 Acc: 0.8520
    val Loss: 0.1396 Acc: 0.7200
    Training complete in 4m 13s
    Best val Acc: 0.737500
    '''

