I want to use perplexity for a task in an NLP project I’m working on. I was reading the docs on transformers and perplexity here and I was baffled by this piece of code:
```python
import torch
from tqdm import tqdm

max_length = model.config.n_positions
stride = 512

nlls = []
for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, encodings.input_ids.size(1))
    trg_len = end_loc - i  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        # loss is the average NLL over the target tokens, so scale back up
        neg_log_likelihood = outputs.loss * trg_len

    nlls.append(neg_log_likelihood)

ppl = torch.exp(torch.stack(nlls).sum() / end_loc)
```
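(In case it matters: `model`, `encodings`, and `device` come from the setup earlier on that docs page, which as far as I remember is roughly the following; the exact model id may differ:)

```python
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from datasets import load_dataset

device = "cuda"
model_id = "gpt2-large"
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

# concatenate the test split into one long sequence and tokenize it
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")
```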
What is tripping me up is: shouldn't line 11 read `trg_len = end_loc - begin_loc` instead of `trg_len = end_loc - i`?
As I understand the code, in each iteration of the loop we move 'to the right' in the encodings by `stride` steps and select a context+target window from `begin_loc` to `end_loc` to calculate the per-word conditional probabilities of the target words (as per line 12). Moreover, in the typical iteration (not too close to the beginning or end of the encodings), `begin_loc = i + stride - max_length` and `end_loc = i + stride`, so that the length of the retrieved context+target window is simply `max_length`. Going by its name, isn't this length what is supposed to be assigned to `trg_len`?
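To make sure I wasn't misreading the arithmetic, I traced the indices with toy numbers (`max_length = 1024` as for GPT-2, `stride = 512`, and a sequence length I made up):

```python
max_length, stride, seq_len = 1024, 512, 2000  # toy numbers

for i in range(0, seq_len, stride):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, seq_len)
    print(f"i={i:4d}  window={end_loc - begin_loc:4d}  trg_len={end_loc - i:4d}")

# i=   0  window= 512  trg_len= 512
# i= 512  window=1024  trg_len= 512
# i=1024  window=1024  trg_len= 512
# i=1536  window= 976  trg_len= 464
```

So in the typical iteration the window really is `max_length` tokens long, but `trg_len` comes out as `stride`, not `max_length`, which is exactly what confuses me.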
Also, according to the text, to avoid accumulating the conditional probabilities of the context words, in each loop all of the target ids are set to -100 except for the very last one, hence line 14 of the code. But line 14 only makes sense if `trg_len` is actually the length of the target ids, which it isn't if defined as in the code. Right?
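(For reference, here's the toy check I ran to confirm that positions labelled -100 are simply ignored by the loss; this is my own snippet, and I'm glossing over the one-token shift the model applies internally:)

```python
import torch
from torch.nn import CrossEntropyLoss

vocab, seq = 10, 5
logits = torch.randn(1, seq, vocab)
labels = torch.randint(0, vocab, (1, seq))
labels[:, :-2] = -100  # mask everything except the last 2 positions

loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100
loss = loss_fct(logits.view(-1, vocab), labels.view(-1))
# `loss` is the mean NLL over the 2 unmasked positions only
```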
Am I making sense or am I making some mistake I am not seeing here?
EDIT: Giving it some more thought, if what I said above makes any sense, then line 14 should be `target_ids[:, :-1] = -100` instead of `target_ids[:, :-trg_len] = -100`, I think. What am I missing here?
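To spell out the difference I mean, here is what the two maskings give on a toy window (numbers made up by me, pretending `max_length = 8` and `trg_len = 4`):

```python
import torch

input_ids = torch.arange(8).unsqueeze(0)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])

docs_version = input_ids.clone()
docs_version[:, :-4] = -100   # the docs' target_ids[:, :-trg_len] = -100
# -> tensor([[-100, -100, -100, -100, 4, 5, 6, 7]]): last 4 tokens are scored

my_version = input_ids.clone()
my_version[:, :-1] = -100     # my proposed target_ids[:, :-1] = -100
# -> tensor([[-100, -100, -100, -100, -100, -100, -100, 7]]): only the last token
```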