BertForNextSentencePrediction with larger batch size

Is there a way to use BertForNextSentencePrediction in inference mode with a batch size larger than 1? Here is my code:

encoding = self.tokenizer(prompt1, prompt2, return_tensors='pt', padding=True, truncation=True, add_special_tokens=True)
print(encoding)
outputs = self.model(**encoding, next_sentence_label=torch.LongTensor([1]))
logits = outputs.logits
#print(logits)

Here prompt1 and prompt2 are lists of sentences, each 10 sentences long. I get this error:

ValueError: Expected input batch_size (10) to match target batch_size (1).

hey @DLiebman i think the problem is that you’re passing a batch of 10 examples, but only a single label in the next_sentence_label argument. changing your code to the following works for me:

outputs = model(**encoding, next_sentence_label=torch.ones(10, dtype=torch.long))
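
for reference, here's a self-contained sketch of the whole thing — the checkpoint name and the example sentences are just placeholders i picked for illustration, swap in your own:

import torch
from transformers import BertTokenizer, BertForNextSentencePrediction

# placeholder checkpoint; use whatever model you're actually loading
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
model.eval()

# toy batch of 3 sentence pairs (yours would be 10)
prompt1 = ["The weather is nice.", "He opened the fridge.", "She plays the violin."]
prompt2 = ["Let's go for a walk.", "It was empty.", "Bananas are yellow."]

encoding = tokenizer(prompt1, prompt2, return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():
    # one label per example, so the label batch matches the input batch
    outputs = model(**encoding, next_sentence_label=torch.ones(len(prompt1), dtype=torch.long))

print(outputs.loss)    # scalar loss averaged over the batch
print(outputs.logits)  # shape (batch_size, 2)

also, since you only want inference, you can drop next_sentence_label entirely — the loss is only computed when labels are passed, and outputs.logits is returned either way.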

yep. something like that will work. thanks.
