当前位置:   article > 正文

(十六)完整的模型验证套路_模型resize操作

模型resize操作

【声明】来源b站视频小土堆PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】_哔哩哔哩_bilibili

测试/demo

套路:利用及已经训练好的模型,给它提供输入

模型是在谷歌的gpu上训练好的,精度为60多,这边直接加载

 test.py

  1. import torch
  2. from PIL import Image
  3. import torchvision
  4. import torch.nn as nn
  5. img_path = "./data/cat.png" #图片路径
  6. image = Image.open(img_path) #读取图片,把图片转换为PIL类型
  7. #png格式是4通道,RGB+透明通道
  8. image = image.convert('RGB') #保留颜色通道,如果图片本来就是三个通道,经过此操作,不变,加上这一步,可以适应png,jpg各种格式的图片
  9. print(image)
  10. transform = torchvision.transforms.Compose([
  11. torchvision.transforms.Resize((32,32)), #把图片进行resize
  12. torchvision.transforms.ToTensor() #转换为Tensor
  13. ])
  14. image = transform(image)
  15. print(image.shape)
  16. # 搭建神经网络
  17. class Model(nn.Module):
  18. def __init__(self) -> None:
  19. super().__init__()
  20. self.model = nn.Sequential(
  21. nn.Conv2d(3, 32, 5, 1, 2),
  22. nn.MaxPool2d(2),
  23. nn.Conv2d(32, 32, 5, 1, 2),
  24. nn.MaxPool2d(2),
  25. nn.Conv2d(32, 64, 5, 1, 2),
  26. nn.MaxPool2d(2),
  27. nn.Flatten(),
  28. nn.Linear(64*4*4, 64),
  29. nn.Linear(64, 10)
  30. )
  31. def forward(self, input):
  32. input = self.model(input)
  33. return input
  34. model = torch.load("./data/mymodel_train_goole29.pth",map_location=torch.device('cpu'))
  35. print(model)
  36. image = torch.reshape(image,(1,3,32,32))
  37. model.eval() #把模型转换为测试类型
  38. with torch.no_grad(): #没有梯度
  39. output =model(image)
  40. print(output) # tensor([[-1.4594, 0.5749, 0.6508, 0.7668, 0.8262, 0.8882, 1.1195, 0.5799,
  41. # -1.7798, -0.3695]])
  42. print(output.argmax(1)) #预测类别 tensor([6])

结果

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/51931
推荐阅读
相关标签
  

闽ICP备14008679号