Foreword

This article focuses on:
- how to calculate perplexity (ppl) in PyTorch
- why ppl can be represented by the model loss
How to calculate ppl

Given a tokenized sequence $X = (x_0, x_1, \dots, x_t)$, ppl is calculated as:

$$\mathrm{ppl}(X) = \exp\left(-\frac{1}{t}\sum_{i=1}^{t}\log p_\theta(x_i \mid x_{<i})\right)$$

- where $\log p_\theta(x_i \mid x_{<i})$ is the log-likelihood of the $i$-th token conditioned on the preceding tokens $x_{<i}$
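To make the formula concrete, here is a minimal sketch (the model name `gpt2` and the sample sentence are placeholders chosen only for illustration): feed the whole sequence through the model once, gather the log-probability of each actual next token, and exponentiate the negative mean.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model and text, assumed only for illustration.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

input_ids = tokenizer("Perplexity measures how well a model predicts text.",
                      return_tensors="pt").input_ids  # shape (1, t+1)

with torch.no_grad():
    logits = model(input_ids).logits  # shape (1, t+1, vocab_size)

# log p_theta(x_i | x_{<i}): the prediction at position i-1 scores token x_i
log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
token_log_likelihood = log_probs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

ppl = torch.exp(-token_log_likelihood.mean())
print(ppl.item())
```

This one-shot version assumes the whole sequence fits in the model's context window; for longer text, a sliding-window evaluation like the following is used.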
```python
import torch
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from datasets import load_dataset

# Setup assumed here (GPT-2 on WikiText-2, as in the Hugging Face perplexity example);
# the snippet needs `model`, `encodings`, and `device` to be defined.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")

max_length = model.config.n_positions
stride = 512
seq_len = encodings.input_ids.size(1)

nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    # tokens already scored in the previous window are masked out with -100
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        # loss is calculated using CrossEntropyLoss which averages over valid labels
        # N.B. the model only calculates loss over trg_len - 1 labels, because it
        # internally shifts the labels to the left by 1.
        neg_log_likelihood = outputs.loss

    nlls.append(neg_log_likelihood)

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

ppl = torch.exp(torch.stack(nlls).mean())
```
Here we can see that `neg_log_likelihood = outputs.loss`, i.e. the model's `CrossEntropyLoss` output, is what we use to represent ppl.
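To confirm that `outputs.loss` really is this cross-entropy with an internal one-position shift (a minimal sketch; `gpt2` and the sample text are placeholder choices), the loss returned when `labels=...` is passed can be reproduced by shifting the logits and labels by one position and applying the usual cross-entropy:

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model and text for illustration.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("ppl is just the exponentiated loss", return_tensors="pt").input_ids

with torch.no_grad():
    outputs = model(input_ids, labels=input_ids)

# Reproduce outputs.loss by hand: shift logits/labels by one position and
# apply standard cross-entropy over the vocabulary dimension.
shift_logits = outputs.logits[:, :-1, :]
shift_labels = input_ids[:, 1:]
manual_loss = F.cross_entropy(
    shift_logits.reshape(-1, shift_logits.size(-1)),
    shift_labels.reshape(-1),
)

print(outputs.loss.item(), manual_loss.item())  # the two values match (up to float error)
```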
Why ppl can be represented by the loss
The cross-entropy loss formula is (PyTorch does not compute it literally from this formula; it does some additional processing, such as applying log-softmax to raw logits and taking class indices instead of one-hot vectors, but the result is the same; see the sketch after this list):

$$\mathcal{L} = -\sum_{i=1}^{C} y_i \log(\hat{y}_i)$$

- where $y$ is the real ground-truth label (a one-hot vector)
- $\hat{y}$ is the distribution predicted by the model
- $C$ is the number of categories, which here can be regarded as the vocabulary size
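A quick way to see that "other processing" (a minimal sketch with made-up logits over a hypothetical vocabulary of $C = 5$ classes): `F.cross_entropy` takes raw logits and an integer class index, applies log-softmax internally, and returns exactly $-\log(\hat{y})$ for the ground-truth class.

```python
import torch
import torch.nn.functional as F

# One prediction over a hypothetical vocabulary of C = 5 classes (raw logits, not probabilities)
logits = torch.tensor([[2.0, 0.5, -1.0, 0.0, 1.0]])
target = torch.tensor([0])  # ground-truth class index (the "1" position of the one-hot y)

# PyTorch's cross-entropy takes raw logits and applies log-softmax internally ...
loss = F.cross_entropy(logits, target)

# ... which is exactly -log(y_hat) for the ground-truth class
y_hat = torch.softmax(logits, dim=-1)
manual = -torch.log(y_hat[0, target[0]])

print(loss.item(), manual.item())  # identical values
```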
In the generation task, each ground-truth $y$ is one-hot: only one position is 1 and the rest are 0, so the formula above reduces to $-\log(\hat{y}_i)$, where $i$ is the index of the correct token. For a sequence $X$, the cross-entropy loss averaged over each token is therefore

$$\mathrm{Loss} = -\frac{1}{t}\sum_{i=1}^{t}\log p_\theta(x_i \mid x_{<i})$$

which is exactly the exponent in the ppl formula, i.e. $\mathrm{ppl} = \exp(\mathrm{Loss})$. Therefore, in actual calculation, we can use the cross-entropy loss to represent the ppl of a sentence.
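As a tiny numeric check (made-up probabilities for a 3-token toy sequence): if the model assigns probabilities 0.5, 0.25, and 0.125 to the three correct tokens, the average cross-entropy loss is about 1.386 and the ppl is exp(1.386) = 4.0.

```python
import torch

# Made-up probabilities a model assigns to the 3 correct tokens of a toy sequence
p = torch.tensor([0.5, 0.25, 0.125])

loss = -torch.log(p).mean()   # average per-token cross-entropy (negative log-likelihood)
ppl = torch.exp(loss)         # perplexity is just the exponentiated loss

print(loss.item())  # ~1.386
print(ppl.item())   # 4.0
```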