当前位置:   article > 正文

transformers自定义模型的保存和加载_transformers 加载本地model

transformers 加载本地model

step1 保存 (my_plbart.py)

  1. #如果一开始用了并行训练最好加上这句
  2. model_to_save = model.module if hasattr(model, 'module') else model
  3. #这样保存的是模型参数,记得格式是.pt
  4. torch.save(model_to_save.state_dict(),output_model_dir+"model-2.pt")

step2 加载 (use_plbart.py)

  1. #因为是自定义模型呀
  2. model = Model()
  3. #拿到保存的参数
  4. model_static_dict = torch.load(output_model_dir+"model-2.pt")
  5. #把参数加载到模型中
  6. model.load_state_dict(model_static_dict)

注意:

两个文件中的 output_model_dir 路径和Model类应该是一致的。

话外:

如果你的模型不是自定义的,而是直接用的transformers中from_pretrained得到的,那么可以直接用save_pretrained进行保存。以上提供的是更一般化的方法,即torch对模型参数保存和加载的支持。

附上完整的模型文件 only_model.py

  1. import torch
  2. from transformers import PLBartConfig, PLBartModel, PLBartTokenizer
  3. plbart_hf_path = "uclanlp/plbart-multi_task-java"
  4. plbart_local_path = "your_path/plbart_files"
  5. output_model_dir = 'your_path/PLBART_huggingface/finetuned_models/'
  6. checkpoint = plbart_local_path
  7. myTokenizer = PLBartTokenizer.from_pretrained(checkpoint)
  8. class Model(torch.nn.Module):
  9. def __init__(self):
  10. super().__init__()
  11. self.pretrained = PLBartModel.from_pretrained(checkpoint)
  12. # 定义一组值全为0的常量
  13. self.register_buffer(
  14. "final_logits_bias",
  15. torch.zeros(1, myTokenizer.vocab_size)
  16. )
  17. self.fc = torch.nn.Linear(768, myTokenizer.vocab_size, bias=False)
  18. # 加载预训练模型的参数
  19. parameters = PLBartConfig()
  20. # self.fc.load_state_dict(parameters.lm_head.state_dict())
  21. self.criterion = torch.nn.CrossEntropyLoss()
  22. def forward(self, input_ids, attention_mask, labels, decoder_input_ids):
  23. logits = self.pretrained(
  24. input_ids=input_ids,
  25. attention_mask=attention_mask,
  26. decoder_input_ids=decoder_input_ids
  27. )
  28. logits = logits.last_hidden_state
  29. logits = self.fc(logits)+self.final_logits_bias
  30. loss = self.criterion(logits.flatten(end_dim=1), labels.flatten())
  31. return {"loss": loss, "logits": logits}

(only_model.py被其他两个py引用,单拎出来形成一个模型文件的好处是,如果直接用use_plbart.py引用my_plbart.py,还会引用进很多无关的代码,Maybe非常耗时甚至卡住)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/249558
推荐阅读
相关标签
  

闽ICP备14008679号