Compared with beam_search, the sample function is much simpler. The one thing to keep in mind is that sample is meant to be used together with a list of logits_warper processors; the three warper classes are analyzed further below. The annotated source of sample follows and is fairly easy to read; a minimal usage sketch is given right after the excerpt.
# auto-regressive generation
while True:
    # prepare model inputs
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

    # forward pass to get next token
    outputs = self(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    if synced_gpus and this_peer_finished:
        cur_len = cur_len + 1
        continue  # don't waste resources running the code we don't need

    # logits of the current step, i.e. the prediction at the last position
    next_token_logits = outputs.logits[:, -1, :]

    # pre-process distribution
    next_token_scores = logits_processor(input_ids, next_token_logits)
    next_token_scores = logits_warper(input_ids, next_token_scores)

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (next_token_scores,)
        if output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
            )
            if self.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

    # sample
    # softmax over the warped scores (F is torch.nn.functional)
    probs = F.softmax(next_token_scores, dim=-1)
    # draw the next token from the resulting multinomial distribution
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

    # finished sentences should have their next token be a padding token
    if eos_token_id is not None:
        assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

    # update generated ids, model inputs, and length for next step
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    model_kwargs = self._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
    )
    cur_len = cur_len + 1

    # if eos_token was found in one sentence, set sentence to finished
    if eos_token_id is not None:
        unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

    # stop when each sentence is finished, or if we exceed the maximum length
    if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            break
        else:
            this_peer_finished = True

return input_ids
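
To see how this fits together from the user side, here is a minimal usage sketch (my own example, not part of the library source). In the transformers versions this walkthrough is based on, passing do_sample=True with num_beams=1 makes generate() build the warper list from temperature, top_k and top_p and dispatch to sample() internally; the gpt2 checkpoint below is just a placeholder.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids
output_ids = model.generate(
    input_ids,
    do_sample=True,   # route generation through sample() instead of greedy/beam search
    temperature=0.8,  # -> TemperatureLogitsWarper
    top_k=50,         # -> TopKLogitsWarper
    top_p=0.95,       # -> TopPLogitsWarper
    max_length=40,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```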

The logits_warper processors shipped with the transformers library are TemperatureLogitsWarper, TopKLogitsWarper and TopPLogitsWarper.
class TemperatureLogitsWarper(LogitsWarper):
    r"""
    :class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution).

    Args:
        temperature (:obj:`float`):
            The value used to module the logits distribution.
    """

    def __init__(self, temperature: float):
        if not isinstance(temperature, float) or not (temperature > 0):
            raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")

        self.temperature = temperature

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        scores = scores / self.temperature
        return scores
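
To make the effect concrete, here is a small toy illustration (my own numbers, not library code): dividing the logits by a temperature below 1 sharpens the softmax distribution, while a temperature above 1 flattens it.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
for t in (0.5, 1.0, 2.0):
    probs = F.softmax(logits / t, dim=-1)
    print(t, [round(p, 3) for p in probs[0].tolist()])
# t=0.5 concentrates almost all mass on the first token; t=2.0 spreads it out.
```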

Drawback of top-k: the discarded tail may still contain many plausible words, which leaves rather few candidates to choose from; in other cases the discarded part contains only a few words, and the remaining candidates are rich enough to produce varied text. The choice of k is therefore critical to the generation result; a toy walk-through of the masking is given after the class below.
class TopKLogitsWarper(LogitsWarper):
    r"""
    :class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements.

    Args:
        top_k (:obj:`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        self.top_k = top_k
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        # torch.topk(scores, top_k)[0][..., -1] is the k-th highest score; every position
        # with a smaller score is marked True in indices_to_remove
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        # replace the True positions with self.filter_value (usually -inf)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
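
The toy walk-through below (my own numbers, not from the library) makes the drawback above concrete: with a fixed top_k, everything below the k-th highest score is masked to -inf, no matter how plausible the tail might have been.

```python
import torch

scores = torch.tensor([[3.0, 2.5, 2.4, 2.3, -1.0]])
top_k = 2
# k-th highest score; here torch.topk returns [[3.0, 2.5]], so the threshold is 2.5
kth_value = torch.topk(scores, top_k)[0][..., -1, None]
masked = scores.masked_fill(scores < kth_value, -float("Inf"))
print(masked)  # tensor([[3.0, 2.5, -inf, -inf, -inf]])
```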

TopPLogitsWarper works with cumulative probabilities: it keeps the head of the predicted distribution whose cumulative probability stays within the top_p hyperparameter and samples only from that subset. For the scatter call in the source, refer to a detailed explanation of the torch.scatter function; a toy trace is also given after the class.
class TopPLogitsWarper(LogitsWarper):
    """
    :class:`transformers.LogitsWarper` that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <=
    prob_cut_off.

    Args:
        top_p (:obj:`float`):
            If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are
            kept for generation.
        filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # sort the current-step scores in descending order
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        # cumulative probabilities over the sorted distribution; they approach 1 towards the
        # tail, and the first position corresponds to the most probable token of this step
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
        # mark the positions whose cumulative probability exceeds the top_p hyperparameter
        sorted_indices_to_remove = cumulative_probs > self.top_p
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        # The right shift changes little in practice: for a mask like [True, ..., True]
        # (drop everything) it has no effect, and for [False, ..., False, True, ..., True]
        # it simply keeps one extra token at the False/True boundary, i.e. the first token
        # that pushes the cumulative probability over the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        # always keep the first token, which is the most confident prediction of this step
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        # scatter sorted_indices_to_remove back to the original (unsorted) positions via
        # sorted_indices, so that each entry of indices_to_remove lines up with the
        # corresponding position in scores; scattering the sorted logits with the same
        # indices would likewise restore the original, unsorted tensor
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        # replace the True positions of scores with self.filter_value (-inf)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
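
Finally, a toy trace of the nucleus filtering (my own numbers, default min_tokens_to_keep): with top_p = 0.92, tokens are kept until the cumulative probability first exceeds 0.92, the right shift keeps the token that crosses the threshold, and the scatter call maps the remove mask back to the original token order.

```python
import torch

# unsorted "probabilities" turned into logits
scores = torch.log(torch.tensor([[0.1, 0.5, 0.04, 0.3, 0.06]]))
top_p = 0.92

sorted_logits, sorted_indices = torch.sort(scores, descending=True)
# sorted probs: 0.5, 0.3, 0.1, 0.06, 0.04 -> cumsum: 0.5, 0.8, 0.9, 0.96, 1.0
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

sorted_indices_to_remove = cumulative_probs > top_p                          # [F, F, F, T, T]
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0                                         # [F, F, F, F, T]

indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
print(scores.masked_fill(indices_to_remove, -float("Inf")))
# only the least likely token (original index 2, prob 0.04) should end up at -inf; the 0.06
# token survives because the right shift keeps the first token past the threshold.
```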
