Transformers generation_utils source code analysis: RepetitionPenaltyLogitsProcessor

This note records the source code Transformers uses to penalize repeated phrases during text generation, with a step-by-step analysis of what each tensor operation does.

import torch

from transformers import LogitsProcessor


class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.

    Args:
        repetition_penalty (:obj:`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
    """

    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # scores is the current step's vocabulary distribution (logits), shape [batch, vocab_size];
        # input_ids is the token sequence fed to the decoder so far, shape [batch, seq].
        # gather picks out, for each sequence, the logits of the tokens already generated.
        score = torch.gather(scores, 1, input_ids)

        # If a gathered logit is negative, multiply by the penalty to push it further down;
        # if it is positive, divide by the penalty. Either way, tokens that have already
        # appeared become less likely.
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
        
        # Scatter the penalized logits back into the original current-step vocabulary
        # distribution (in place).
        scores.scatter_(1, input_ids, score)
        return scores
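
To see the penalty in action, here is a minimal sketch that calls the processor by hand; the toy logits, the vocabulary size of 5, and the penalty value 1.2 are made up for illustration:

import torch

from transformers import RepetitionPenaltyLogitsProcessor

# Toy setup: batch of 1, vocabulary of 5 tokens; tokens 0 and 3 were already generated.
input_ids = torch.tensor([[0, 3]])
scores = torch.tensor([[2.0, 1.0, 0.5, -1.0, 0.0]])

processor = RepetitionPenaltyLogitsProcessor(penalty=1.2)
# clone() because scatter_ modifies the scores tensor in place
penalized = processor(input_ids, scores.clone())

print(penalized)
# token 0:  2.0 / 1.2 ≈ 1.6667  (positive logit divided    -> less likely)
# token 3: -1.0 * 1.2 = -1.2    (negative logit multiplied -> less likely)
# unseen tokens 1, 2 and 4 keep their original logits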

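In practice you rarely instantiate the processor yourself: passing repetition_penalty to model.generate() makes Transformers build it internally. A minimal sketch, assuming a gpt2 checkpoint is available:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The weather today", return_tensors="pt").input_ids
# repetition_penalty > 1.0 discourages repeating tokens that were already produced
output = model.generate(input_ids, max_length=30, repetition_penalty=1.2)
print(tokenizer.decode(output[0]))
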
Origin blog.csdn.net/yangyanbao8389/article/details/121651056