赞
踩
- class Seq2SeqTrainerForChatGLM(PeftTrainer):
- r"""
- Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
- """
-
- def save_predictions(
- self,
- predict_results: PredictionOutput,
- tokenizer: PreTrainedTokenizer
- ) -> None:
- r"""
- Saves model predictions to `output_dir`.
- A custom behavior that not contained in Seq2SeqTrainer.
- """
- if not self.is_world_process_zero():
- return
-
- preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
- labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
-
- preds = preds[:, labels.shape[1]:] # remove prompts
- preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
- labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]
-
- output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
- logger.info(f"Saving prediction results to {output_prediction_file}")
- with open(output_prediction_file, "w", encoding="utf-8") as writer:
- res: List[str] = []
- for pred, label in zip(preds, labels):
- res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
- writer.write("\n".join(res))

这是一个Python代码段,定义了一个名为Seq2SeqTrainerForChatGLM的类,该类继承自PeftTrainer。这个类特定用于处理序列到序列模型的训练,以及生成指标(如BLEU和ROUGE)的计算。
接下来,我们将逐行解释这段代码:
class Seq2SeqTrainerForChatGLM(PeftTrainer): - 这行定义了一个新的类,Seq2SeqTrainerForChatGLM,它继承自PeftTrainer类。
def save_predictions(self, predict_results: PredictionOutput, tokenizer: PreTrainedTokenizer) -> None: - 这是一个save_predictions方法,它接受两个参数,一个是predict_results(预测结果,类型为PredictionOutput),另一个是tokenizer(预训练的分词器,类型为PreTrainedTokenizer)。
if not self.is_world_process_zero(): return - 这行检查当前进程是否为主进程(进程编号为0)。如果不是,这个方法就结束运行。
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) - 这行创建了一个新的预测值数组,其中非忽略索引的预测值被保留,忽略索引的预测值被替换为填充符号的ID。
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id) - 这行创建了一个新的标签数组,其中非忽略索引的标签值被保留,忽略索引的标签值被替换为填充符号的ID。
preds = preds[:, labels.shape[1]:] - 这行移除了预测值中的提示,提示是在预测数组的开始部分。
preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds] - 这行使用分词器解码预测值,移除特殊符号,并删除前后的空白。
labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels] - 这行使用分词器解码标签,移除特殊符号,并删除前后的空白。
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") - 这行定义了预测结果的输出文件路径。
logger.info(f"Saving prediction results to {output_prediction_file}") - 这行向日志发送一条信息,表明预测结果将被保存到哪个文件。
with open(output_prediction_file, "w", encoding="utf-8") as writer: - 这行打开预测结果输出文件,以写入模式,准备写入预测结果。
res: List[str] = [] - 这行初始化一个空的列表,准备收集每一行的预测结果。
for pred, label in zip(preds, labels): - 这行开始一个循环,遍历所有的预测值和标签。
res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) - 在循环中,这行将每对预测和标签以JSON格式转换为字符串,然后添加到结果列表中。
writer.write("\n".join(res)) - 这行将结果列表连接成一个字符串,每个结果之间用换行符隔开,然后写入到文件中。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。