赞
踩
# Save the whole model object (architecture + weights) by pickling it.
torch.save(model_object,'resnet.pth')
# Load the whole model back. This unpickles arbitrary Python objects, so only
# load files from a trusted source; the model's class definition must also be
# importable in the loading program.
# NOTE: since PyTorch 2.6, torch.load defaults to weights_only=True, which
# rejects whole-module checkpoints like this one; on those versions pass
# weights_only=False (trusted files only).
model=torch.load('resnet.pth')

# Recommended approach: save only the parameters and buffers (the state_dict)
# instead of pickling the whole module. This reuses the file name from the
# snippet above, so it overwrites that checkpoint.
torch.save(my_resnet.state_dict(), "resnet.pth")

# Load a saved state_dict into an already-constructed model of the same
# architecture (resnet_model must be built before this line).
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a
# machine without CUDA; load_state_dict then copies the values into the
# model's existing parameters, so they stay on the model's current device.
resnet_model.load_state_dict(torch.load("resnet.pth", map_location="cpu"))

# Transfer pretrained ResNet-152 weights into `model`, keeping only the
# entries that `model` can actually accept.
# (torchvision >= 0.13 deprecates pretrained=True in favour of
#  weights=models.ResNet152_Weights.DEFAULT.)
resnet152 = models.resnet152(pretrained=True)
pretrained_dict = resnet152.state_dict()

model_dict = model.state_dict()

# Drop pretrained entries whose key is not in model_dict, and also those whose
# tensor shape differs (e.g. a classifier head `fc.weight`/`fc.bias` resized
# for a different number of classes). Copying a mismatched tensor into
# model_dict would make load_state_dict raise a size-mismatch error.
pretrained_dict = {
    k: v
    for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}

# Overwrite the matching entries of the model's current state_dict.
model_dict.update(pretrained_dict)
# Load the merged dict: it contains exactly model's keys, so strict loading
# succeeds, and entries without a pretrained match keep their current values.
model.load_state_dict(model_dict)

# Fine-tuning setup: load a pretrained ResNet-152 and replace its final fully
# connected layer with a new 10-way classifier.
resnet = torchvision.models.resnet152(pretrained=True)

# Read the input width from the existing head instead of hard-coding 2048, so
# the same line also works for other ResNet depths (resnet18/34 use 512).
resnet.fc = torch.nn.Linear(resnet.fc.in_features, 10)

Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。