Batched BertForMaskedLM inference loss issue

Hello there,

I want to test a pre-trained model (for a low-resource language) on its ability to predict masked words. One way to do this is to compute the pseudo-perplexity (PPPL) of the model (based on this paper). In short, the approach is as follows:

  1. Given a sentence of N tokens, create N copies of it, one per token.
  2. In each copy, replace a different token with the mask token, so that every token is masked exactly once, and build the matching labels tensor and attention mask tensor.
  3. Compute the loss for each masked token by forwarding the masked tensor, the attention tensor, and the labels tensor.
  4. Compute PPPL as given in the paper, in other words, the exponential of the mean loss per masked token.
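
Steps 1 and 2 can be sketched on plain Python lists, without a model or tokenizer. This is only an illustration: the token IDs (101, 102, 103, 0) and the helper name mask_each_token are assumptions made up for the example, not values taken from my tokenizer.

```python
# Hypothetical toy IDs for illustration only.
CLS, SEP, MASK, PAD, IGNORE = 101, 102, 103, 0, -100

def mask_each_token(input_ids):
    """Return (masked_copies, labels): one copy per real token, each with one mask."""
    # Real tokens exclude padding and the two special tokens [CLS] and [SEP].
    num_real = sum(1 for t in input_ids if t != PAD) - 2
    masked_copies, labels = [], []
    for i in range(num_real):
        pos = i + 1  # skip [CLS] at position 0
        copy = list(input_ids)
        copy[pos] = MASK
        # Only the masked position keeps its original ID as a label.
        label = [IGNORE] * len(input_ids)
        label[pos] = input_ids[pos]
        masked_copies.append(copy)
        labels.append(label)
    return masked_copies, labels

sentence = [CLS, 7, 8, 9, SEP, PAD]
masked, labels = mask_each_token(sentence)
for m, l in zip(masked, labels):
    print(m, l)
```

Each printed row is one copy of the sentence with a single masked position, next to a label row that is -100 everywhere except at that position.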

The problem with this is that it is expensive to do for every single token when we have a large number of sentences. To make it cheaper, I decided to concatenate all the copies of all the sentences into one large tensor and then split that tensor into batches. The problem with this solution is that the resulting loss is different for each batch size. I don't know why the loss is changing. Am I doing something wrong? Also, I would really appreciate any other comments on how to make my code run faster. Thank you very much and have a wonderful day!

Here is my code:

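import math

import torch
import torch.nn.functional as F
from transformers import BertConfig, BertForMaskedLM, BertTokenizer

# Function to create copies and insert masks
def repeat_mask_across_sentence(model, tensor_input, mask_token_id):
	# Number of real tokens: non-padding tokens minus [CLS] and [SEP] (assumes the pad token ID is 0)
	num_masks = int(torch.count_nonzero(tensor_input)) - 2
	# One copy of the sentence per token to be masked
	repeat_input = tensor_input.repeat(num_masks, 1)
	# Ones on the first superdiagonal: row i selects position i + 1, which skips [CLS]
	mask = torch.ones(tensor_input.size(-1) - 1).diag(1)[:num_masks]
	masked_input = repeat_input.masked_fill(mask == 1, mask_token_id)
	# Keep the original ID only at the masked position; -100 is ignored by the loss
	label = repeat_input.masked_fill(masked_input != mask_token_id, -100)
	return masked_input, label, num_masks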

# Loading the pretrained model from the Hugging Face platform
config = BertConfig.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(model_path, use_fast = True)
model = BertForMaskedLM.from_pretrained(model_path, config=config).to("cuda:0")

# Tokenize all sentences, padding is necessary since we need to fit them into batches
tensors = tokenizer(sentences, padding=True, return_tensors='pt')

loss_per_token, masked_inputs, attention_masks, labels = [], [], [], []

# Creating a masked_input_id tensor, attention mask tensor, and a labels tensor for each sentence
for input_id, attention_mask, token_type_id in zip(tensors["input_ids"], tensors["attention_mask"], tensors['token_type_ids']):
	masked_input, label, n = repeat_mask_across_sentence(model, input_id, tokenizer.mask_token_id)
	masked_inputs.append(masked_input)
	attention_masks.append(attention_mask.repeat(masked_input.size(dim=0), 1))
	labels.append(label)

# Merge and split the tensors based on the given batch size
masked_inputs = torch.cat(masked_inputs, 0).split(batch_size)
attention_masks = torch.cat(attention_masks, 0).split(batch_size)
labels = torch.cat(labels, 0).split(batch_size)

# Calculate the loss for each masked token; no gradients are needed for inference
with torch.no_grad():
	for token_ids, attention, label in zip(masked_inputs, attention_masks, labels):
		output = model(input_ids=token_ids.to("cuda:0"), attention_mask=attention.to("cuda:0"), labels=label.to("cuda:0"))
		# Flatten to (batch * seq_len, vocab_size), taking the vocabulary size from the logits themselves
		logits_flat = output['logits'].view(-1, output['logits'].size(-1)).to("cpu")
		labels_flat = label.view(-1).to("cpu")
		losses = F.cross_entropy(logits_flat, labels_flat, reduction='none')
		# Keep only the masked positions (label != -100)
		loss_per_token.extend(losses[labels_flat != -100].tolist())
		del output, logits_flat, labels_flat, losses
		torch.cuda.empty_cache()

print("PPPL:", math.exp(sum(loss_per_token)/len(loss_per_token)))
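
The last line pools every per-token loss and averages once before exponentiating. As a sanity check on that choice, here is a minimal sketch with made-up loss values (no model involved; the function names pooled_pppl and mean_of_batch_means_pppl are hypothetical) comparing it with averaging the per-batch mean losses instead. Under these assumptions, only the pooled version should be independent of the batch size, apart from floating-point noise.

```python
import math

# Made-up per-token losses (negative log-likelihoods), purely illustrative.
token_losses = [0.5, 1.5, 2.0, 4.0, 1.0]

def split(items, batch_size):
    """Split a list into consecutive chunks of at most batch_size items."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

def pooled_pppl(losses, batch_size):
    # Collect every per-token loss across batches, then average once.
    pooled = [x for batch in split(losses, batch_size) for x in batch]
    return math.exp(sum(pooled) / len(pooled))

def mean_of_batch_means_pppl(losses, batch_size):
    # Average within each batch first, then average the batch means.
    # Batches with fewer tokens get the same weight as full ones.
    means = [sum(b) / len(b) for b in split(losses, batch_size)]
    return math.exp(sum(means) / len(means))

for bs in (2, 3, 5):
    print(bs, pooled_pppl(token_losses, bs), mean_of_batch_means_pppl(token_losses, bs))
```

In this toy setup the second column stays the same for every batch size, while the third column changes whenever the batches hold different numbers of tokens.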