当前位置:   article > 正文

Pytorch:加载断点(pth)权重参数

Pytorch:加载断点(pth)权重参数

一、保存的模型参数及权重

  1. #保存模型
  2. torch.save(model_object,'resnet.pth')
  3. #加载模型
  4. model=torch.load('resnet.pth')

二、仅保存模型的权重

  1. torch.save(my_resnet.state_dict(),"resnet.pth")
  2. resnet_model.load_state_dict(torch.load("resnet.pth"))

三、仅加载部分参数

  1. resnet152=models.resnet152(pretrained=True)
  2. pretrained_dict=resnet152.state_dict()
  3. model_dict=model.state_dict()
  4. #将pretrained_dict里不属于model_dict的键去除掉
  5. pretrained_dict={k:v for k,v in pretrained_dict.items() if k in model_dict}
  6. #更新现有的model_dict
  7. model_dict.update(pretrained_dict)
  8. #加载真正需要的state_dict
  9. model.load_state_dict(model_dict)

四、微调预训练模型 

  1. resnet=torchvision.models.resnet152(pretrained=True)
  2. resnet.fc=torch.nn.Linear(2048,10)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小惠珠哦/article/detail/998525
推荐阅读
相关标签
  

闽ICP备14008679号