当前位置:   article > 正文

深度学习笔记--使用VGG16预训练模型_vgg16预训练权重

vgg16预训练权重

1--常用预训练模型下载

        参考常用预训练模型下载地址

        如果不想额外下载预训练模型,可通过以下代码设置自动下载对应的权重文件(下载速度可能较慢):

  1. vgg16 = models.vgg16(pretrained=True)
  2. # 替换下面代码:
  3. # vgg16 = models.vgg16()
  4. # weights = torch.load('./vgg16-397923af.pth')
  5. # vgg16.load_state_dict(weights)

2--使用VGG16预训练模型

  1. from torchvision import models
  2. import cv2
  3. from torchvision import transforms
  4. import torch
  5. from PIL import Image
  6. import numpy as np
  7. # 初始化模型
  8. device = 0
  9. vgg16 = models.vgg16().to(device)
  10. weights = torch.load('./vgg16-397923af.pth')
  11. vgg16.load_state_dict(weights)
  12. # 前处理
  13. normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406], # imagenet dataset的均值
  14. std = [0.229, 0.224, 0.225])
  15. tran = transforms.Compose([transforms.Resize((224,224)),
  16. transforms.ToTensor(),
  17. transforms.Normalize(mean = [0.485, 0.456, 0.406],
  18. std = [0.229, 0.224, 0.225])])
  19. if __name__ == "__main__":
  20. # 读取图片
  21. img = cv2.imread('./test1.jpg')
  22. # 前处理
  23. img = Image.fromarray(np.uint8(img)).convert('RGB') # [3, 224, 224]
  24. img = tran(img)
  25. img.unsqueeze_(dim=0) # [1, 3, 224, 224]
  26. # 推理
  27. output = vgg16(img.to(device)) # [1, 1000]
  28. # 后处理
  29. output = output.data[0] # [1000]
  30. output = output.cpu().detach().numpy()
  31. print(output.shape)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/351072
推荐阅读
相关标签
  

闽ICP备14008679号