当前位置:   article > 正文

pytorch模型载入之gpu和cpu互转_torch.load gpu模型转为cpu模型

torch.load gpu模型转为cpu模型

Pytorch训练模型fine-tunning、模型推理等环节常常涉及到模型加载,其中会涉及到将不同平台、版本的模型相互转化:

Case-1.载入多GPU模型

  1. pretained_model = torch.load(’muti_gpus_model.pth‘) # 网络+权重
  2. # 载入为单gpu模型
  3. gpu_model = pretrained_model.module # GPU-version
  4. # 载入为cpu模型
  5. model = ModelArch()
  6. pretained_dict = pretained_model.module.state_dict()
  7. model.load_state_dict(pretained_dict) # CPU-version

Case-2.载入多GPU权重

  1. model = ModelArch(para).cuda(0) # 网络结构
  2. model = torch.nn.DataParallel(model, device_ids=[0]) # 将model转为muit-gpus模式
  3. checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage) # 载入weights
  4. model.load_state_dict(checkpoint) # 用weights初始化网络
  5. # 载入为单gpu模型
  6. gpu_model = model.module # GPU-version
  7. # 载入为cpu模型
  8. model = ModelArch(para)
  9. model.load_state_dict(gpu_model.state_dict())
  10. torch.save(cpu_model.state_dict(), 'cpu_mode.pth') # cpu模型存储, 注意这里的state_dict后的()必须加上,否则报'function' object has no attribute 'copy'错误

Case-3.载入CPU权重 | [inference]

  1. # 载入为cpu版本
  2. model = ModelArch(para)
  3. checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage) # 载入weights
  4. # 载入为gpu版本
  5. model = ModelArch(para).cuda() # 网络结构
  6. checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage.cuda(0)) # 载入weights
  7. model.load_state_dict(checkpoint) # 用weights初始化网络
  8. # 载入为muti-gpus版本
  9. model = ModelArch(para).cuda() # 网络结构
  10. model = torch.nn.DataParallel(model, device_ids=[0, 1]) # device_ids根据自己需求改!
  11. checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage.cuda(0)) # 载入weights
  12. model.module.load_state_dict(checkpoint) # 用weights初始化网络

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/酷酷是懒虫/article/detail/851509?site
推荐阅读
相关标签
  

闽ICP备14008679号