
Source-code walkthrough of sample generation in transformers.generation_utils (covering the temperature, top-k, and top-p warper functions)

        The sample function is much simpler than beam_search, but note that sample is meant to be used together with a list of logits_warper processors; the three warper classes are analyzed further below. The annotated source of the sample loop follows and is fairly easy to read.

    # auto-regressive generation
    while True:
        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue  # don't waste resources running the code we don't need

        # grab the prediction for the current step (the logits at the last position)
        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # sample
        # softmax over the warped scores to obtain the final sampling distribution
        probs = F.softmax(next_token_scores, dim=-1)
        # draw one token per sequence via multinomial sampling
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        cur_len = cur_len + 1

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # stop when each sentence is finished, or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    return input_ids
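For reference, here is a minimal usage sketch of how this loop is reached in practice (the gpt2 checkpoint and the prompt are just arbitrary examples): passing do_sample=True makes generate() dispatch to the sample loop above, and temperature/top_k/top_p are wrapped into the logits_warper list it receives.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("The meaning of life is", return_tensors="pt")
    output_ids = model.generate(
        inputs.input_ids,
        do_sample=True,      # route generation through the sample() loop above
        temperature=0.7,     # wrapped into TemperatureLogitsWarper
        top_k=50,            # wrapped into TopKLogitsWarper
        top_p=0.9,           # wrapped into TopPLogitsWarper
        max_length=50,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))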

        The logits_warper family of processors in the transformers library consists of TemperatureLogitsWarper, TopKLogitsWarper, and TopPLogitsWarper.
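Before going through them one by one, here is a minimal sketch (all values arbitrary) of chaining the three warpers by hand with the public LogitsProcessorList container, in the same order generate() applies them (temperature first, then top-k, then top-p):

    import torch
    from transformers import (
        LogitsProcessorList,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )

    logits_warper = LogitsProcessorList([
        TemperatureLogitsWarper(0.7),
        TopKLogitsWarper(top_k=50),
        TopPLogitsWarper(top_p=0.9),
    ])

    input_ids = torch.tensor([[0]])   # dummy prefix, shape (batch, seq_len)
    logits = torch.randn(1, 100)      # fake next-token logits over a 100-token vocabulary
    warped = logits_warper(input_ids, logits)
    probs = torch.softmax(warped, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)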

  • TemperatureLogitsWarper: reshapes the predicted distribution by dividing the logit scores by the temperature hyperparameter (a small numeric demo follows the class below). In most studies, the choice of temperature tends to follow these patterns:
    • In practice, it pays to experiment with multiple temperature values; keeping some randomness without breaking the structure of the text often yields the most interesting generations.
    • With a high temperature, the local structure of the text tends to break down, and many tokens degenerate into semi-random strings.
    • Larger temperatures make the generated text more random, interesting, and even creative; they can occasionally surface novel (misspelled) words.
    • Smaller temperatures lead to highly repetitive and predictable text, but the content stays closer to the corpus (highly realistic), with essentially every word drawn from it.
    • As the temperature approaches 0, temperature sampling degenerates into greedy search, which always picks the highest-probability token.
    class TemperatureLogitsWarper(LogitsWarper):
        r"""
        :class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution).

        Args:
            temperature (:obj:`float`):
                The value used to modulate the logits distribution.
        """

        def __init__(self, temperature: float):
            if not isinstance(temperature, float) or not (temperature > 0):
                raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
            self.temperature = temperature

        def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
            # dividing by the temperature rescales the logits before softmax
            scores = scores / self.temperature
            return scores
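A tiny numeric check of that division (toy logits, arbitrary values): temperatures below 1 sharpen the softmax toward the argmax, temperatures above 1 flatten it.

    import torch

    logits = torch.tensor([2.0, 1.0, 0.1])
    for t in (0.5, 1.0, 2.0):
        # t < 1 widens the gaps between logits, t > 1 shrinks them
        print(t, torch.softmax(logits / t, dim=-1))
    # t=0.5 -> ~[0.86, 0.12, 0.02]  (near-greedy)
    # t=1.0 -> ~[0.66, 0.24, 0.10]
    # t=2.0 -> ~[0.50, 0.30, 0.19]  (flatter, more random sampling)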
  • TopKLogitsWarper: sorts the model's current-step probability distribution and keeps only the top-k entries (a toy run follows the class below).
    • Pros: basic top-k sampling improves generation quality because it discards the low-probability candidates (removing the tail), which keeps generation from drifting off topic.
    • Cons: because k is fixed, the discarded tail may still contain many plausible words, leaving too few choices; in other cases the tail holds few plausible words, and keeping more candidates would have allowed richer text. A fixed k cannot adapt to the shape of the distribution.

    • The choice of k is therefore critical to the quality of the generated output.
    class TopKLogitsWarper(LogitsWarper):
        r"""
        :class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements.

        Args:
            top_k (:obj:`int`):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
                All filtered values will be set to this float value.
            min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
                Minimum number of tokens that cannot be filtered.
        """

        def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
            if not isinstance(top_k, int) or top_k <= 0:
                raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
            self.top_k = top_k
            self.filter_value = filter_value
            self.min_tokens_to_keep = min_tokens_to_keep

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1))  # Safety check
            # Remove all tokens with a probability less than the last token of the top-k:
            # torch.topk(scores, top_k)[0][..., -1, None] is the k-th largest logit, and every
            # position whose score falls below it is marked True
            indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
            # replace the True positions with self.filter_value (normally -inf)
            scores = scores.masked_fill(indices_to_remove, self.filter_value)
            return scores
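A toy run of the same filtering (arbitrary scores), showing how everything below the k-th largest logit is pushed to -inf:

    import torch

    scores = torch.tensor([[1.0, 3.0, 0.5, 2.0]])  # fake logits for a 4-token vocabulary
    top_k = 2
    # value of the k-th largest logit, kept broadcastable with [..., None]
    threshold = torch.topk(scores, top_k)[0][..., -1, None]
    filtered = scores.masked_fill(scores < threshold, -float("inf"))
    print(filtered)                          # tensor([[-inf, 3., -inf, 2.]])
    print(torch.softmax(filtered, dim=-1))   # mass only on the two kept tokens: ~[0, 0.73, 0, 0.27]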
  • TopPLogitsWarper: works with cumulative probabilities, keeping the predicted tokens whose cumulative probability stays within the top_p hyperparameter and then sampling from them (a toy walkthrough follows the class below). For the scatter call in the source, refer to a detailed explanation of torch.scatter.

    class TopPLogitsWarper(LogitsWarper):
        """
        :class:`transformers.LogitsWarper` that performs top-p, i.e. restricting to the smallest set of top
        tokens whose probabilities sum to :obj:`top_p` or higher.

        Args:
            top_p (:obj:`float`):
                If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or
                higher are kept for generation.
            filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
                All filtered values will be set to this float value.
            min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
                Minimum number of tokens that cannot be filtered.
        """

        def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
            if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
                raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
            self.top_p = top_p
            self.filter_value = filter_value
            self.min_tokens_to_keep = min_tokens_to_keep

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            # sort the current-step scores in descending order
            sorted_logits, sorted_indices = torch.sort(scores, descending=True)
            # cumulative probabilities approach 1 toward the tail; the first entry corresponds to
            # the highest-probability token of the predicted distribution
            cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

            # mark positions whose cumulative probability exceeds the top_p threshold
            sorted_indices_to_remove = cumulative_probs > self.top_p
            if self.min_tokens_to_keep > 1:
                # Keep at least min_tokens_to_keep (set to min_tokens_to_keep - 1 because we add the first one below)
                sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
            # Shift the mask to the right so the first token above the threshold is also kept.
            # The shift barely changes the outcome. Two examples: for [True, ..., True] (drop everything),
            # shifting has no practical effect; for [False, False, True, ..., True] (keep a prefix),
            # it merely keeps one extra token at the False/True boundary.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            # always keep the first token: it is the highest-confidence prediction at this step
            sorted_indices_to_remove[..., 0] = 0

            # scatter the sorted mask back to the original indexing: each entry of the result says
            # whether the score at the same position should be removed. (Scattering the sorted logits
            # by sorted_indices would likewise restore their pre-sort order.)
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            # replace the True positions in scores with self.filter_value (-inf)
            scores = scores.masked_fill(indices_to_remove, self.filter_value)
            return scores
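A toy walkthrough of the same steps (arbitrary scores), which also makes the scatter trick concrete:

    import torch

    scores = torch.tensor([[1.0, 3.0, 0.5, 2.0]])
    sorted_logits, sorted_indices = torch.sort(scores, descending=True)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    print(cumulative_probs)   # ~tensor([[0.6308, 0.8629, 0.9482, 1.0000]])

    top_p = 0.9
    sorted_indices_to_remove = cumulative_probs > top_p   # [[False, False, True, True]]
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0                  # [[False, False, False, True]]
    # scatter writes each sorted-order flag back to its original vocabulary position
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    print(scores.masked_fill(indices_to_remove, -float("inf")))
    # tensor([[1., 3., -inf, 2.]]) -> only the lowest-probability token (logit 0.5) is dropped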
