当前位置:   article > 正文

T5模型在训练过程中实时计算验证集上准确度,自定义compute_metrics_计算验证集准确率

计算验证集准确率

T5模型不同于BERT类模型,它是一个seq2seq模型,在训练过程中预测结果实时返回的是字典长度的置信度。

将T5用于解决NLU问题时,想要在训练过程中实时监测在验证集上的准确度,也很简单,只需要添加自定义compute_metrics函数。

以下为采用transformers框架训练添加自定义compute_metrics函数的代码:

  1. def compute_accuracy(pred):
  2. ## 1.处理 pred.predictions
  3. # 每个样本的预测结果为vocab大小
  4. predict_res = torch.Tensor(pred.predictions[0]) # size:[验证集样本量, label的token长度, vocab大小]
  5. pred_ids = predict_res.argmax(dim=2)
  6. ## 2.处理 pred.label_ids
  7. labels_actual = torch.LongTensor(pred.label_ids)
  8. ## 3.计算accuracy
  9. total_num = labels_actual.shape[0]
  10. acc = torch.sum(torch.all(torch.eq(pred_ids, labels_actual), dim=1))/total_num
  11. return {'accuracy': acc}
  12. trainer = Trainer(
  13. model=model,
  14. args=training_args,
  15. train_dataset=train_dataset,
  16. eval_dataset=eval_dataset,
  17. compute_metrics=compute_accuracy # 添加自定义compute_metrics
  18. )

推荐使用wandb监控训练状态,实时可见此自定义accuracy

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/348514
推荐阅读
相关标签
  

闽ICP备14008679号