This note records the source code used to mitigate repeated phrases in text generation; the comments in the code explain what each operation does.
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.

    Every token id that already occurs in ``input_ids`` has its score rescaled: negative scores are
    multiplied by ``penalty`` and non-negative scores are divided by it, so for ``penalty > 1`` the
    score of an already-seen token always decreases.

    Args:
        penalty (:obj:`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.

    Raises:
        ValueError: If ``penalty`` is not a ``float`` or is not strictly positive.
    """

    def __init__(self, penalty: float):
        # `not (penalty > 0)` (rather than `penalty <= 0`) also rejects NaN.
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Penalize the scores of every token id present in ``input_ids``.

        Args:
            input_ids: Token ids of the sequence fed to the decoder so far; used as the
                index along dim 1 of ``scores``. Expected shape ``(batch, seq_len)``.
            scores: Current-step scores over the vocabulary. ``torch.gather`` requires
                it to have the same number of dimensions as ``input_ids``, with dim 1
                indexed by token id, i.e. expected shape ``(batch, vocab_size)``.

        Returns:
            ``scores`` itself, with the entries of already-seen tokens rescaled.
            Note that ``scores`` is modified in place by ``scatter_``.
        """
        # Pick out the current-step score of every token that is already in the sequence.
        score = torch.gather(scores, 1, input_ids)

        # Lower the score of tokens that have already appeared: if score < 0 the
        # repetition penalty has to be multiplied (dividing would move it towards zero),
        # otherwise the score is divided by the penalty.
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)

        # Write the penalized values back into the current-step vocabulary scores (in place).
        scores.scatter_(1, input_ids, score)
        return scores