BartForConditionalGeneration "logits" shape is wrong/unexpected

I'm using BartForConditionalGeneration with batch_size = 2 … all the inputs look right, but when I examine the logits, the shape is torch.Size([2, 1, 50264])

The inputs are:

x['input_ids'].shape => torch.Size([2, 256])
x['attention_mask'].shape => torch.Size([2, 256])
x['decoder_input_ids'].shape => torch.Size([2, 68])

What might I be doing wrong? Or is there a bug in the model?

Could you share the code you use to produce the logits? Did you just call the model, or did you use model.generate?

Btw, 50264 is BART's vocab_size.
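
You can confirm that from the config; a quick check, assuming the 'facebook/bart-large-cnn' checkpoint (any BART checkpoint exposes its vocabulary size the same way):

from transformers import BartConfig

# 'facebook/bart-large-cnn' is an assumed checkpoint here.
config = BartConfig.from_pretrained('facebook/bart-large-cnn')
print(config.vocab_size)  # 50264 for this checkpoint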

It seems the problem occurs when passing decoder_input_ids. Here’s my code:

logits = self.hf_model(
    input_ids=x['input_ids'],
    attention_mask=x['attention_mask'],
    decoder_input_ids=x['decoder_input_ids'],
    labels=None,
    return_dict=True,
).logits

where self.hf_model is an instance of BartForConditionalGeneration.

This returns a tensor with shape torch.Size([2, 1, 50264]), when the expected shape is torch.Size([2, 68, 50264]).

SOLVED

Apparently, even if you are calculating the loss on your own, you have to pass in labels if you pass in decoder_input_ids.
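
For reference, here is a minimal sketch of what ended up giving me full-length logits, plus a variant. The use_cache=False variant is an assumption on my part, based on the model disabling caching internally whenever labels are passed:

# Workaround A: pass labels so the model returns logits for every
# decoder position. The labels here are a dummy copy of
# decoder_input_ids; the returned loss is simply ignored.
out = self.hf_model(
    input_ids=x['input_ids'],
    attention_mask=x['attention_mask'],
    decoder_input_ids=x['decoder_input_ids'],
    labels=x['decoder_input_ids'],
    return_dict=True,
)
logits = out.logits  # torch.Size([2, 68, 50264])

# Workaround B (assumption): disable caching explicitly, which is what
# passing labels does internally.
out = self.hf_model(
    input_ids=x['input_ids'],
    attention_mask=x['attention_mask'],
    decoder_input_ids=x['decoder_input_ids'],
    use_cache=False,
    return_dict=True,
)
logits = out.logits  # torch.Size([2, 68, 50264])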

Y'all may want to add this to the docs, or update the behavior to not require the labels argument.

Thanks - wg


Great that you solved it!
If you want to compute the loss on your own, perhaps you can use BartModel (which does not require labels) instead of BartForConditionalGeneration.
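
For example, a minimal sketch of computing the loss yourself with BartModel. The projection through the shared embeddings and the target handling are my assumptions, not an exact equivalent of the LM head (which also adds a final_logits_bias):

import torch.nn.functional as F
from transformers import BartModel, BartTokenizer

model = BartModel.from_pretrained('facebook/bart-large')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

batch = tokenizer(['My friends are cool.'], return_tensors='pt')
targets = tokenizer(['They are.'], return_tensors='pt')

out = model(
    input_ids=batch['input_ids'],
    attention_mask=batch['attention_mask'],
    decoder_input_ids=targets['input_ids'],
    return_dict=True,
)

# BartModel has no LM head, so project the decoder hidden states onto the
# vocabulary with the tied input embeddings to get logits.
logits = out.last_hidden_state @ model.shared.weight.T  # [batch, tgt_len, vocab]

# Cross-entropy against the target ids (in a real setup you would shift
# decoder_input_ids one position right relative to the targets).
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    targets['input_ids'].reshape(-1),
    ignore_index=tokenizer.pad_token_id,
)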

=======================
About the original question: for me, the following code gave the correct outputs and shape (I used TF). But it seems we need 'max_length' padding.

from transformers import TFBartForConditionalGeneration, BartTokenizer

model = TFBartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

src_texts = ['My friends are cool but they eat too many carbs. I really want them to be healthy, so I buy them vegetable.']
tgt_texts = ['I buy them vegetable.']

# padding='max_length' pads both source and target to the tokenizer's
# model_max_length, rather than to the longest sequence in the batch.
x = tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts, return_tensors='tf', padding='max_length')

out = model(x)

Note: in my case, padding had to be specified as 'max_length'; without it, calling the model failed.
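
A quick way to verify the shape (assuming a transformers version where the TF model returns an output object; on older versions, use out[0] instead of out.logits):

# Full-length logits: (batch_size, decoder_seq_len, vocab_size),
# with decoder_seq_len padded up to the tokenizer's model_max_length.
print(out.logits.shape)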