Confused by the perplexity calculation in the docs

Hello everyone,

I want to use perplexity for a task in an NLP project I'm working on. I was reading the 🤗 Transformers docs page on perplexity and was baffled by this piece of code:

```python
import torch
from tqdm import tqdm

max_length = model.config.n_positions
stride = 512

nlls = []
for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, encodings.input_ids.size(1))
    trg_len = end_loc - i    # may be different from stride on last loop
    input_ids = encodings.input_ids[:,begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:,:-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        neg_log_likelihood = outputs[0] * trg_len

    nlls.append(neg_log_likelihood)

ppl = torch.exp(torch.stack(nlls).sum() / end_loc)
```
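For reference, the last line of the snippet is meant to implement the usual definition of perplexity, the exponentiated average negative log-likelihood. Writing $i_k$, $b_k$, $e_k$ for the loop index, `begin_loc` and `end_loc` of window $k$, $L_k$ for that window's `outputs[0]`, $n_k$ for its `trg_len`, and $N$ for the final `end_loc`, the bookkeeping works out roughly as follows. This assumes $L_k$ is the mean loss over the positions $i_k, \dots, e_k - 1$ that `target_ids[:,:-trg_len] = -100` leaves unmasked (positions are 0-based, as in the code):

$$
\mathrm{PPL}(X) = \exp\left(-\frac{1}{N}\sum_{j=0}^{N-1}\log p_\theta(x_j \mid x_{<j})\right)
$$

$$
L_k\, n_k \approx -\sum_{j=i_k}^{e_k-1}\log p_\theta\left(x_j \mid x_{b_k}, \dots, x_{j-1}\right), \qquad n_k = e_k - i_k
$$

$$
\mathrm{ppl} = \exp\left(\frac{1}{N}\sum_k L_k\, n_k\right)
$$

The approximation sign covers the first window, where the model's internal one-token shift between logits and labels leaves one fewer scored token than $n_k$. The last line recovers $\mathrm{PPL}(X)$ (with each context truncated to its window) only if the unmasked ranges of all windows together cover each of the $N$ tokens exactly once, so that $\sum_k n_k = N$.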

What trips me up is this: shouldn't line 11 read `trg_len = end_loc - begin_loc` instead of `trg_len = end_loc - i`?

As I understand the code, each iteration of the loop moves "to the right" through the encodings by `stride` tokens and selects a context+target span from `begin_loc` to `end_loc` (line 12) in order to compute the per-word conditional probability of the target word. Moreover, in a typical iteration (not too close to the beginning or end of the encodings), `begin_loc = i + stride - max_length` and `end_loc = i + stride`, so the retrieved context+target span is exactly `max_length` tokens long. Going by its name, isn't this length what should be assigned to `trg_len`?
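A pure-Python sketch that replays only the index arithmetic of the loop makes the two candidate definitions easy to compare. It uses small, hypothetical sizes (10 tokens, `max_length=6`, `stride=3`) instead of the real `model.config.n_positions` and `stride = 512`; `window_bounds` is a helper name made up for this illustration, and no model or tensors are involved.

```python
def window_bounds(seq_len, max_length, stride):
    """Replay only the index arithmetic of the docs' loop (no model, no tensors)."""
    rows = []
    for i in range(0, seq_len, stride):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, seq_len)
        trg_len = end_loc - i  # definition used in the docs
        # window_len is the proposed alternative, end_loc - begin_loc
        rows.append((i, begin_loc, end_loc, end_loc - begin_loc, trg_len))
    return rows


# Hypothetical toy sizes: 10 tokens, max_length=6, stride=3.
rows = window_bounds(seq_len=10, max_length=6, stride=3)
for i, begin_loc, end_loc, window_len, trg_len in rows:
    print(f"i={i} begin_loc={begin_loc} end_loc={end_loc} "
          f"window_len={window_len} trg_len={trg_len}")
# i=0 begin_loc=0 end_loc=3 window_len=3 trg_len=3
# i=3 begin_loc=0 end_loc=6 window_len=6 trg_len=3
# i=6 begin_loc=3 end_loc=9 window_len=6 trg_len=3
# i=9 begin_loc=6 end_loc=10 window_len=4 trg_len=1
print(sum(row[-1] for row in rows))  # 10, the final end_loc
```

In the steady-state iterations `window_len` (`end_loc - begin_loc`) is indeed `max_length`, while the docs' `trg_len` stays at `stride` until the shorter last pass. The `trg_len` values also add up to the final `end_loc`, which is the denominator in the last line of the snippet.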

Also, as I read the text, to avoid accumulating the conditional probabilities of the context words, each iteration sets all of the target ids to -100 except the very last one, hence line 14. But line 14 only makes sense if `trg_len` really is the number of target ids, which it isn't when defined as in the code. Right?
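On the -100 part: PyTorch's `CrossEntropyLoss` skips targets equal to its `ignore_index`, which defaults to -100, so the loss returned for `labels=target_ids` is averaged only over the unmasked labels. That is why the loop multiplies `outputs[0]` by `trg_len` to get back a sum. The sketch below mimics that averaging in plain Python with made-up per-token NLL values; `mean_nll` and all numbers are hypothetical, and it ignores the one-position shift between logits and labels that the real model performs internally.

```python
IGNORE_INDEX = -100  # PyTorch's CrossEntropyLoss default ignore_index


def mean_nll(token_nlls, labels):
    """Average per-token NLLs over positions whose label is not IGNORE_INDEX."""
    kept = [nll for nll, lab in zip(token_nlls, labels) if lab != IGNORE_INDEX]
    return sum(kept) / len(kept)


# Hypothetical per-position NLLs for a 6-token window, with trg_len = 3.
token_nlls = [2.0, 1.5, 3.0, 1.0, 2.0, 3.0]
labels = [11, 12, 13, 14, 15, 16]
trg_len = 3
# Plain-list version of target_ids[:, :-trg_len] = -100
labels[:-trg_len] = [IGNORE_INDEX] * (len(labels) - trg_len)

loss = mean_nll(token_nlls, labels)  # mean over the last 3 positions only
print(labels)          # [-100, -100, -100, 14, 15, 16]
print(loss)            # 2.0
print(loss * trg_len)  # 6.0, the summed NLL of the unmasked positions
```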

Am I making sense, or is there some mistake I'm not seeing?

EDIT: On further thought, if what I said above makes sense, then I think line 14 should be `target_ids[:,:-1] = -100` instead of `target_ids[:,:-trg_len] = -100`. What am I missing here?
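The three masking variants discussed in this post can be compared on a single window. The sketch takes a hypothetical window with `i=6`, `begin_loc=3`, `end_loc=9` (what `max_length=6`, `stride=3` give on a 10-token sequence) and uses token positions as stand-in ids; `mask` is a made-up helper that mimics `target_ids[:, :-n_keep] = -100` on a plain list.

```python
def mask(labels, n_keep):
    """Mimic `target_ids[:, :-n_keep] = -100` on one row, as a plain list."""
    row = list(labels)
    n_masked = len(row[:-n_keep])  # 0 when n_keep == len(row): the slice is empty
    row[:-n_keep] = [-100] * n_masked
    return row


# Hypothetical window: i=6, begin_loc=3, end_loc=9 (max_length=6, stride=3).
i, begin_loc, end_loc = 6, 3, 9
window = list(range(begin_loc, end_loc))  # token positions stand in for ids

print(mask(window, end_loc - i))          # docs: [-100, -100, -100, 6, 7, 8]
print(mask(window, end_loc - begin_loc))  # end_loc - begin_loc: [3, 4, 5, 6, 7, 8]
print(mask(window, 1))                    # [:, :-1]: [-100, -100, -100, -100, -100, 8]
```

With `trg_len = end_loc - begin_loc`, the slice `[:-trg_len]` is empty, so nothing at all would be masked; the docs' version keeps the last `end_loc - i` positions, and `[:, :-1]` keeps only the final one.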