当前位置:   article > 正文

pytorch实现resnet50(训练+测试+模型转换)_pytorch resnet50预训练模型

pytorch resnet50预训练模型

本章使用 PyTorch 在 CIFAR-10 数据集上训练 ResNet50,并给出测试代码以及模型转换流程(PyTorch .pth → ONNX → TensorFlow .pb)。

数据集:CIFAR-10,由 torchvision.datasets.CIFAR10 在首次运行时自动下载到 ./cifar 目录。

代码工程(训练得到的模型保存在 model2/ 目录下,后续的测试与转换脚本均从该目录读取):

1.train.py

  1. import torch
  2. from torch import nn, optim
  3. import torchvision.transforms as transforms
  4. from torchvision import datasets
  5. from torch.utils.data import DataLoader
  6. from resnet50 import ResNet50
  7. # 用CIFAR-10 数据集进行实验
  8. def main():
  9. batchsz = 2
  10. cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
  11. transforms.Resize((224, 224)),
  12. transforms.ToTensor(),
  13. transforms.Normalize(mean=[0.485, 0.456, 0.406],
  14. std=[0.229, 0.224, 0.225])
  15. ]), download=True)
  16. cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
  17. cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
  18. transforms.Resize((224, 224)),
  19. transforms.ToTensor(),
  20. transforms.Normalize(mean=[0.485, 0.456, 0.406],
  21. std=[0.229, 0.224, 0.225])
  22. ]), download=True)
  23. cifar_test = DataLoader(cifar_test, batch_size=1, shuffle=True)
  24. x, label = iter(cifar_train).next()
  25. print('x:', x.shape, 'label:', label.shape)
  26. device = torch.device('cpu')
  27. model = ResNet50().to(device)
  28. print(*list(model.children())[-3:-2])
  29. criteon = nn.CrossEntropyLoss().to(device)
  30. optimizer = optim.Adam(model.parameters(), lr=1e-3)
  31. # print(model)
  32. print(iter(cifar_test).next()[0].shape)
  33. for epoch in range(1):
  34. model.train()
  35. for batchidx, (x, label) in enumerate(cifar_train):
  36. if batchidx<=2:
  37. x, label = x.to(device), label.to(device)
  38. logits = model(x)
  39. loss = criteon(logits, label)
  40. # backprop
  41. optimizer.zero_grad()
  42. loss.backward()
  43. optimizer.step()
  44. print("epoch:",epoch, "index:",batchidx)
  45. else:
  46. continue
  47. print(epoch, 'loss:', loss.item())
  48. # # # PATH="model/test.pth"
  49. torch.save(model, "model2/test.pth")
  50. torch.save(model.state_dict(),"model2/test2.pth")
  51. model.eval()
  52. with torch.no_grad():
  53. # test
  54. total_correct = 0
  55. total_num = 0
  56. for idx, (x, label) in enumerate(cifar_test):
  57. if idx<=5:
  58. x, label = x.to(device), label.to(device)
  59. logits = model(x)
  60. pred = logits.argmax(dim=1)
  61. correct = torch.eq(pred, label).float().sum().item()
  62. total_correct += correct
  63. total_num += x.size(0)
  64. # print(pred)
  65. acc = total_correct / total_num
  66. print(epoch, 'test acc:', acc)
  67. if __name__ == '__main__':
  68. main()
  69. # # 保存整个网络
  70. # torch.save(net, PATH)
  71. # # 保存网络中的参数, 速度快,占空间少
  72. # torch.save(net.state_dict(),PATH)
  73. # #--------------------------------------------------
  74. # #针对上面一般的保存方法,加载的方法分别是:
  75. # model_dict=torch.load(PATH)
  76. # model_dict=model.load_state_dict(torch.load(PATH))

2.test_pth.py

  1. from resnet50 import ResNet50
  2. import torch
  3. from PIL import Image
  4. from torchvision import transforms
  5. import cv2
  6. import numpy as np
  7. def prediect(img_path):
  8. device = torch.device('cpu')
  9. model=torch.load("model2/test.pth")
  10. model=model.to(device)
  11. # model = ResNet50()
  12. # weight=torch.load("model/test2.pth")
  13. # model.load_state_dict(weight)
  14. # model=model.to(device)
  15. img=cv2.imread(img_path)
  16. img=cv2.resize(img, (224, 224))
  17. img=np.reshape(img,(1,224,224,3))
  18. img=img.transpose(0,3,1,2).copy()
  19. print(img.shape)
  20. img_ = torch.Tensor(img)
  21. torch.no_grad()
  22. outputs = model(img_)
  23. _, predicted = torch.max(outputs, 1)
  24. print('pred :',outputs, predicted)
  25. if __name__ == '__main__':
  26. img_path="img/dog2.jpg"
  27. prediect(img_path)

3.resnet50.py

  1. import torch
  2. import torch.nn as nn
  3. from torch.nn import functional as F
  4. class ResNet50BasicBlock(nn.Module):
  5. def __init__(self, in_channel, outs, kernerl_size, stride, padding):
  6. super(ResNet50BasicBlock, self).__init__()
  7. self.conv1 = nn.Conv2d(in_channel, outs[0], kernel_size=kernerl_size[0], stride=stride[0], padding=padding[0])
  8. self.bn1 = nn.BatchNorm2d(outs[0])
  9. self.conv2 = nn.Conv2d(outs[0], outs[1], kernel_size=kernerl_size[1], stride=stride[0], padding=padding[1])
  10. self.bn2 = nn.BatchNorm2d(outs[1])
  11. self.conv3 = nn.Conv2d(outs[1], outs[2], kernel_size=kernerl_size[2], stride=stride[0], padding=padding[2])
  12. self.bn3 = nn.BatchNorm2d(outs[2])
  13. def forward(self, x):
  14. out = self.conv1(x)
  15. out = F.relu(self.bn1(out))
  16. out = self.conv2(out)
  17. out = F.relu(self.bn2(out))
  18. out = self.conv3(out)
  19. out = self.bn3(out)
  20. return F.relu(out + x)
  21. class ResNet50DownBlock(nn.Module):
  22. def __init__(self, in_channel, outs, kernel_size, stride, padding):
  23. super(ResNet50DownBlock, self).__init__()
  24. # out1, out2, out3 = outs
  25. # print(outs)
  26. self.conv1 = nn.Conv2d(in_channel, outs[0], kernel_size=kernel_size[0], stride=stride[0], padding=padding[0])
  27. self.bn1 = nn.BatchNorm2d(outs[0])
  28. self.conv2 = nn.Conv2d(outs[0], outs[1], kernel_size=kernel_size[1], stride=stride[1], padding=padding[1])
  29. self.bn2 = nn.BatchNorm2d(outs[1])
  30. self.conv3 = nn.Conv2d(outs[1], outs[2], kernel_size=kernel_size[2], stride=stride[2], padding=padding[2])
  31. self.bn3 = nn.BatchNorm2d(outs[2])
  32. self.extra = nn.Sequential(
  33. nn.Conv2d(in_channel, outs[2], kernel_size=1, stride=stride[3], padding=0),
  34. nn.BatchNorm2d(outs[2])
  35. )
  36. def forward(self, x):
  37. x_shortcut = self.extra(x)
  38. out = self.conv1(x)
  39. out = self.bn1(out)
  40. out = F.relu(out)
  41. out = self.conv2(out)
  42. out = self.bn2(out)
  43. out = F.relu(out)
  44. out = self.conv3(out)
  45. out = self.bn3(out)
  46. return F.relu(x_shortcut + out)
  47. class ResNet50(nn.Module):
  48. def __init__(self):
  49. super(ResNet50, self).__init__()
  50. self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
  51. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  52. self.layer1 = nn.Sequential(
  53. ResNet50DownBlock(64, outs=[64, 64, 256], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
  54. ResNet50BasicBlock(256, outs=[64, 64, 256], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
  55. ResNet50BasicBlock(256, outs=[64, 64, 256], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
  56. )
  57. self.layer2 = nn.Sequential(
  58. ResNet50DownBlock(256, outs=[128, 128, 512], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2], padding=[0, 1, 0]),
  59. ResNet50BasicBlock(512, outs=[128, 128, 512], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
  60. ResNet50BasicBlock(512, outs=[128, 128, 512], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
  61. ResNet50DownBlock(512, outs=[128, 128, 512], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0])
  62. )
  63. self.layer3 = nn.Sequential(
  64. ResNet50DownBlock(512, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2], padding=[0, 1, 0]),
  65. ResNet50BasicBlock(1024, outs=[256, 256, 1024], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1],
  66. padding=[0, 1, 0]),
  67. ResNet50BasicBlock(1024, outs=[256, 256, 1024], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1],
  68. padding=[0, 1, 0]),
  69. ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1],
  70. padding=[0, 1, 0]),
  71. ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1],
  72. padding=[0, 1, 0]),
  73. ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1],
  74. padding=[0, 1, 0])
  75. )
  76. self.layer4 = nn.Sequential(
  77. ResNet50DownBlock(1024, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2],
  78. padding=[0, 1, 0]),
  79. ResNet50DownBlock(2048, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1],
  80. padding=[0, 1, 0]),
  81. ResNet50DownBlock(2048, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1],
  82. padding=[0, 1, 0])
  83. )
  84. self.avgpool = nn.AvgPool2d(kernel_size = 7,stride=1,ceil_mode=False)
  85. # self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
  86. self.fc = nn.Linear(2048, 10)
  87. # 使用卷积代替全连接
  88. self.conv11=nn.Conv2d(2048, 10, kernel_size=1, stride=1, padding=0)
  89. def forward(self, x):
  90. out = self.conv1(x)
  91. out = self.maxpool(out)
  92. out = self.layer1(out)
  93. out = self.layer2(out)
  94. out = self.layer3(out)
  95. out = self.layer4(out)
  96. out = self.avgpool(out)
  97. out=self.conv11(out)
  98. out = out.reshape(x.shape[0], -1)
  99. # out = self.fc(out)
  100. return out
  101. if __name__ == '__main__':
  102. x = torch.randn(1, 3, 224, 224)
  103. net = ResNet50()
  104. out = net(x)
  105. print('out.shape: ', out.shape)
  106. print(out)

4.pth2onnx.py

  1. import torch
  2. from torchsummary import summary
  3. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  4. model = torch.load("model2/test.pth") # pytorch模型加载
  5. model.eval()
  6. for name in model.state_dict():
  7. print(name)
  8. summary(model, (3, 224, 224))
  9. input_shape=list(map(int, "1,3,224,224".split(",")))
  10. x = torch.randn(input_shape) # 生成张量
  11. x = x.to(device)
  12. export_onnx_file = "model2/test.onnx" # 目的ONNX文件名
  13. torch.onnx.export(model, x, export_onnx_file, verbose=True)
  14. # torch.onnx.export(model, x, export_onnx_file, verbose=True, export_params=True, do_constant_folding=True, opset_version=11)
  15. # input_names=['boxes']
  16. # output_names=['layer1.1.conv1.bias']
  17. # torch.onnx.export(model, x, export_onnx_file,
  18. # export_params=True,
  19. # do_constant_folding=True,
  20. # input_names=input_names,
  21. # output_names=output_names
  22. # )

5.test_onnx_v1.py

  1. import cv2
  2. import numpy as np
  3. import onnxruntime as rt
  4. def image_process(image_path):
  5. mean = np.array([[[0.485, 0.456, 0.406]]]) # 训练的时候用来mean和std
  6. std = np.array([[[0.229, 0.224, 0.225]]])
  7. img = cv2.imread(image_path)
  8. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  9. img = cv2.resize(img, (224, 224)) # (96, 96, 3)
  10. image = img.astype(np.float32)/255.0
  11. image = (image - mean)/ std
  12. image = image.transpose((2, 0, 1)) # (3, 96, 96)
  13. image = image[np.newaxis,:,:,:] # (1, 3, 96, 96)
  14. image = np.array(image, dtype=np.float32)
  15. return image
  16. def onnx_runtime():
  17. imgdata = image_process('img/test.jpg')
  18. sess = rt.InferenceSession("model2/test.onnx")
  19. input_name = sess.get_inputs()[0].name
  20. output_name = sess.get_outputs()[0].name
  21. pred_onnx = sess.run([output_name], {input_name: imgdata})
  22. print("outputs:",np.array(pred_onnx)[0].shape)
  23. onnx_runtime()

6.test_onnx_v2.py

  1. import numpy as np
  2. import torch
  3. import onnx
  4. import onnxruntime
  5. import pickle
  6. # 测试数据
  7. x = torch.randn(1,3,224,224, requires_grad=False)
  8. print(type(x),x.shape)
  9. # 使用 ONNX 的 API 检查 ONNX 模型
  10. onnx_model = onnx.load("model2/test.onnx")
  11. onnx.checker.check_model(onnx_model)
  12. # onnx模型测试
  13. ort_session = onnxruntime.InferenceSession("model2/test.onnx")
  14. def to_numpy(tensor):
  15. return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
  16. #结果输出
  17. ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
  18. ort_outs = ort_session.run(None, ort_inputs)
  19. ort_out = ort_outs[0]
  20. print(x.shape, ort_out.shape)
  21. # torch模型测试
  22. # model=torch.load("test/person_reid.pth",map_location='cpu')
  23. # model.eval()
  24. # torch_out = model(x)
  25. # 比较ONNX 和 PyTorch 的结果
  26. # np.testing.assert_allclose(to_numpy(torch_out), ort_out, rtol=1e-03, atol=1e-05)
  27. # print("模型没有太大差异!")

7.onnx2pb.py

  1. import onnx
  2. from onnx_tf.backend import prepare
  3. def onnx2pb(onnx_input_path, pb_output_path):
  4. onnx_model = onnx.load(onnx_input_path) # load onnx model
  5. tf_exp = prepare(onnx_model) # prepare tf representation
  6. tf_exp.export_graph(pb_output_path) # export the model
  7. if __name__ == "__main__":
  8. # onnx_input_path = 'test/person_reid.onnx'
  9. # pb_output_path = 'test/person_reid2.pb'
  10. onnx_input_path = 'model2/test.onnx'
  11. pb_output_path = 'model2/test.pb'
  12. onnx2pb(onnx_input_path, pb_output_path)

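8.test_pb.py(测试 pb 模型,并附 ONNX 结果对比代码。注意:脚本中的模型路径 test/gg.pb、输入尺寸 256×128 以及张量名 data / reid_embedding 对应的是另一个行人重识别模型,用于本文模型时需相应修改)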

  1. import tensorflow as tf
  2. from tensorflow.python.framework import graph_util
  3. from tensorflow.python import pywrap_tensorflow
  4. import cv2
  5. import numpy as np
  6. import torch
  7. import onnx
  8. import onnxruntime
  9. import pickle
  10. def recognize(img, pb_file_path):
  11. with tf.Graph().as_default():
  12. output_graph_def = tf.GraphDef()
  13. with open(pb_file_path, "rb") as f:#主要步骤即为以下标出的几步,1、2步即为读取图
  14. output_graph_def.ParseFromString(f.read())# 1.将模型文件解析为二进制放进graph_def对象
  15. _ = tf.import_graph_def(output_graph_def, name="")# 2.import到当前图
  16. with tf.Session() as sess:
  17. init = tf.global_variables_initializer()
  18. sess.run(init)
  19. graph = tf.get_default_graph()# 3.获得当前图
  20. # # 4.get_tensor_by_name获取需要的节点
  21. # x = graph.get_tensor_by_name("IteratorGetNext_1:0")
  22. # y_out = graph.get_tensor_by_name("resnet_v1_50_1/predictions/Softmax:0")
  23. x = graph.get_tensor_by_name("data:0")
  24. y_out = graph.get_tensor_by_name("reid_embedding:0")
  25. # img=np.random.normal(size=(1, 224, 224, 3))
  26. # img=cv2.imread(jpg_path)
  27. # img=cv2.resize(img, (128, 256))
  28. # img=np.reshape(img,(1,128,256,3))
  29. # img=img.transpose(0,3,1,2).copy()
  30. # print(img.shape)
  31. #执行
  32. output = sess.run(y_out, feed_dict={x:img})
  33. pred=np.argmax(output, axis=1)
  34. return output
  35. # print("预测结果:", output.shape, output, "预测label:", pred)
  36. jpg_path="img/test.jpg"
  37. img=cv2.imread(jpg_path)
  38. img=cv2.resize(img, (128, 256))
  39. img=np.reshape(img,(1,128,256,3))
  40. img=img.transpose(0,3,1,2).copy()
  41. print(img.shape)
  42. x = torch.randn(1,3,256,128, requires_grad=False)
  43. img=x
  44. # 测试pb
  45. a=recognize(img, "test/gg.pb")
  46. print(a.shape)
  47. # b=recognize(img, "test/person_reid2.pb")
  48. # np.testing.assert_allclose(a, b, rtol=1e-03, atol=1e-05)
  49. # print(a.shape,a[0][4],b[0][4])
  50. # # # 测试数据
  51. # # x = torch.randn(1,3,256,128, requires_grad=False)
  52. # # # x=torch.from_numpy(img)
  53. # # # x.requires_grad=False
  54. # # 使用 ONNX 的 API 检查 ONNX 模型
  55. # onnx_model = onnx.load("test/person_reid.onnx")
  56. # onnx.checker.check_model(onnx_model)
  57. # # onnx模型测试
  58. # ort_session = onnxruntime.InferenceSession("test/person_reid.onnx")
  59. # def to_numpy(tensor):
  60. # return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
  61. # #结果输出
  62. # ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
  63. # ort_outs = ort_session.run(None, ort_inputs)
  64. # ort_out = ort_outs[0]
  65. # print(ort_out.shape, ort_out[0][4])
  66. # np.testing.assert_allclose(a, ort_out, rtol=1e-03, atol=1e-05)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/371334
推荐阅读
相关标签
  

闽ICP备14008679号