
attention_mask / pad_token_id warning: "The attention mask and the pad token id were not set"

    device = "cuda"
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    # Generate the response
    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=512
    )
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    # Decode the response
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

When running inference with an LLM, the code above produces this warning: The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results. Setting pad_token_id to eos_token_id :151643 for open-end generation.

Modified code:

    device = "cuda"
    # The tokenizer call returns both input_ids and attention_mask
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    # Generate the response
    generated_ids = model.generate(
        model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,  # pass the mask explicitly
        max_new_tokens=512,
        pad_token_id=tokenizer.eos_token_id,  # set pad_token_id to eos_token_id (151643 here) explicitly
    )
    # Drop the prompt tokens so that only the newly generated tokens remain
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    # Decode the response
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

We pass attention_mask to model.generate and assign tokenizer.eos_token_id to pad_token_id. This resolves the warning.
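To make the role of the attention mask concrete, here is a minimal pure-Python sketch of what a mask encodes when a batch is padded. The helper `pad_batch` and the token ids are hypothetical and only for illustration; in practice the tokenizer builds the mask for you. The sketch assumes left padding, which is the usual choice for decoder-only generation, and reuses the id 151643 from the warning as the pad id.

```python
from typing import List, Tuple


def pad_batch(
    sequences: List[List[int]], pad_token_id: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad token id sequences and build the matching attention mask.

    Hypothetical helper for illustration; not part of any library.
    """
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        # Padding goes on the left so the real tokens end each row
        input_ids.append([pad_token_id] * n_pad + seq)
        # 0 marks padding positions to ignore, 1 marks real tokens
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


EOS_TOKEN_ID = 151643  # id reported in the warning, reused here as the pad id

ids, mask = pad_batch([[11, 12, 13], [21]], pad_token_id=EOS_TOKEN_ID)
print(ids)   # [[11, 12, 13], [151643, 151643, 21]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```

For a single unpadded prompt the mask is all ones, so the warning is harmless in that case. Once pad and EOS share one id, however, the mask is the only thing that distinguishes padding from a genuine end-of-sequence token, which is why passing it explicitly matters for batched inputs.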
