Hello everyone,
I want to use perplexity for a task in an NLP project I’m working on. I was reading the transformers docs on perplexity and was baffled by this piece of code:
```python
import torch
from tqdm import tqdm

# `model`, `encodings`, and `device` are defined earlier in the guide
max_length = model.config.n_positions
stride = 512

nlls = []
for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, encodings.input_ids.size(1))
    trg_len = end_loc - i  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100  # mask out everything except the targets

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        neg_log_likelihood = outputs[0] * trg_len

    nlls.append(neg_log_likelihood)

ppl = torch.exp(torch.stack(nlls).sum() / end_loc)
```
What is tripping me up is… shouldn’t the line `trg_len = end_loc - i` read `trg_len = end_loc - begin_loc` instead?
As I understand the code, in each iteration of the loop we move ‘to the right’ in the encodings by `stride` steps, selecting a context+target span from `begin_loc` to `end_loc` to calculate the per-word conditional probability of the target words (the `input_ids` slice). Moreover, in a typical iteration (not too close to the beginning or end of the encodings), `begin_loc = i + stride - max_length` and `end_loc = i + stride`, so the length of the retrieved context+target is simply `max_length`. Going by its name, isn’t this length what is supposed to be assigned to the `trg_len` variable?
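To make my confusion concrete, here is a quick trace of the loop indices with toy numbers (`max_length`, `stride`, and the sequence length are made up and shrunk so the arithmetic is easy to follow; the variable names match the snippet above):

```python
# Toy trace of the loop indices from the perplexity snippet,
# with made-up small values instead of the real model/tokenizer.
max_length = 8   # stand-in for model.config.n_positions
stride = 4
seq_len = 20     # stand-in for encodings.input_ids.size(1)

for i in range(0, seq_len, stride):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, seq_len)
    trg_len = end_loc - i
    window_len = end_loc - begin_loc
    print(f"i={i}  begin={begin_loc}  end={end_loc}  "
          f"window_len={window_len}  trg_len={trg_len}")
```

With these numbers, in the middle iterations `end_loc - begin_loc` is 8 (i.e. `max_length`) while `end_loc - i` is 4 (i.e. `stride`), which is exactly why the name `trg_len` confuses me.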
Also, according to the text, to avoid accumulating conditional probabilities of the words in the context, in each loop all of the target ids are set to -100 except for the very last one, and hence the line `target_ids[:, :-trg_len] = -100`. But that line only makes sense if `trg_len` is actually the length of the target ids, which it isn’t if defined as in the code. Right?
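For context on why -100 specifically: it is the default `ignore_index` of PyTorch’s cross-entropy loss, so any label set to -100 contributes nothing to the loss. A minimal sketch (random logits, made-up shapes):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10)                       # 5 positions, vocab of 10
labels = torch.tensor([-100, -100, -100, 7, 2])   # first 3 positions masked

# F.cross_entropy ignores labels equal to -100 by default (ignore_index=-100)
loss_masked = F.cross_entropy(logits, labels)
# same result as computing the loss only over the unmasked positions
loss_unmasked_only = F.cross_entropy(logits[3:], labels[3:])
```

So whatever slice of `target_ids` is set to -100 is simply dropped from the loss average.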
Am I making sense, or am I making some mistake I am not seeing here?
EDIT: Giving it some more thought, if what I said above makes any sense, then the masking line should be `target_ids[:, :-1] = -100` instead of `target_ids[:, :-trg_len] = -100`, I think. What am I missing here?
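To show concretely what I mean, here is the difference between the two masking variants on a toy tensor (the values and `trg_len` are made up for illustration):

```python
import torch

trg_len = 3
target_ids = torch.arange(8).unsqueeze(0)  # shape (1, 8): [[0, 1, ..., 7]]

a = target_ids.clone()
a[:, :-trg_len] = -100   # the code's version: keeps the last trg_len=3 tokens
# a is [[-100, -100, -100, -100, -100, 5, 6, 7]]

b = target_ids.clone()
b[:, :-1] = -100         # my proposed version: keeps only the very last token
# b is [[-100, -100, -100, -100, -100, -100, -100, 7]]
```

So the code as written scores the last `trg_len` tokens of each window, whereas my reading of the text suggested only the final token should remain unmasked.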