赞
踩
如果不想额外下载预训练模型,可通过以下代码设置自动下载对应的权重文件(下载速度可能较慢):
- vgg16 = models.vgg16(pretrained=True)
-
- # 替换下面代码:
- # vgg16 = models.vgg16()
- # weights = torch.load('./vgg16-397923af.pth')
- # vgg16.load_state_dict(weights)
- from torchvision import models
- import cv2
- from torchvision import transforms
- import torch
- from PIL import Image
- import numpy as np
-
- # 初始化模型
- device = 0
- vgg16 = models.vgg16().to(device)
- weights = torch.load('./vgg16-397923af.pth')
- vgg16.load_state_dict(weights)
-
- # 前处理
- normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406], # imagenet dataset的均值
- std = [0.229, 0.224, 0.225])
-
- tran = transforms.Compose([transforms.Resize((224,224)),
- transforms.ToTensor(),
- transforms.Normalize(mean = [0.485, 0.456, 0.406],
- std = [0.229, 0.224, 0.225])])
-
- if __name__ == "__main__":
- # 读取图片
- img = cv2.imread('./test1.jpg')
-
- # 前处理
- img = Image.fromarray(np.uint8(img)).convert('RGB') # [3, 224, 224]
- img = tran(img)
- img.unsqueeze_(dim=0) # [1, 3, 224, 224]
-
- # 推理
- output = vgg16(img.to(device)) # [1, 1000]
-
- # 后处理
- output = output.data[0] # [1000]
- output = output.cpu().detach().numpy()
- print(output.shape)

Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。