BartForConditionalGeneration "logits" shape is wrong/unexpected

Using BartForConditionalGeneration with batch_size = 2 … all the inputs look right, but when I examine the logits their shape is torch.Size([2, 1, 50264])

The inputs are:

x['input_ids'].shape => torch.Size([2, 256])
x['attention_mask'].shape => torch.Size([2, 256])
x['decoder_input_ids'].shape => torch.Size([2, 68])

What may I be doing wrong? Or is there a bug in the model?

Could you share the command you use to produce the logits? Did you call the model directly, or use model.generate?

Btw, 50264 is Bart's vocab_size.

It seems the problem occurs when passing decoder_input_ids. Here’s my code:

logits = self.hf_model(x['input_ids'], x['attention_mask'], x['decoder_input_ids'], labels=None, return_dict=True).logits

where self.hf_model is an instance of BartForConditionalGeneration.

It returns a tensor with shape torch.Size([2, 1, 50264]), when the expected shape is torch.Size([2, 68, 50264]).


Apparently, even if you are calculating the loss on your own, you have to pass in labels if you pass in decoder_input_ids.
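For reference, a call along these lines does return full-length logits. This is only a sketch: it uses a tiny randomly initialized BartConfig (hypothetical sizes, chosen to avoid downloading pretrained weights) rather than a real checkpoint, passes everything by keyword, and includes labels as discussed above:

```python
import torch
from transformers import BartConfig, BartForConditionalGeneration

# Tiny randomly initialized BART; sizes are illustrative, not facebook/bart-large
config = BartConfig(vocab_size=50264, d_model=32,
                    encoder_layers=1, decoder_layers=1,
                    encoder_attention_heads=2, decoder_attention_heads=2,
                    encoder_ffn_dim=64, decoder_ffn_dim=64)
model = BartForConditionalGeneration(config)

input_ids = torch.randint(0, config.vocab_size, (2, 256))
attention_mask = torch.ones_like(input_ids)
decoder_input_ids = torch.randint(0, config.vocab_size, (2, 68))
labels = torch.randint(0, config.vocab_size, (2, 68))

# Keyword arguments avoid positional mix-ups; labels makes the model
# return both a loss and the full sequence of logits
out = model(input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            labels=labels,
            return_dict=True)
print(out.logits.shape)  # torch.Size([2, 68, 50264])
```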

Y'all may want to add this to the docs, or update the behavior so the labels argument isn't required.

Thanks - wg


Great that you solved it!
If you want to compute the loss on your own, you can use BartModel (which does not require labels) instead of BartForConditionalGeneration.
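A rough sketch of that approach, again with a tiny randomly initialized config (hypothetical sizes) so nothing is downloaded. BartModel returns decoder hidden states; here a separate Linear head projects them to the vocabulary (the real BartForConditionalGeneration instead ties this projection to the shared embeddings), and the loss is computed manually:

```python
import torch
import torch.nn.functional as F
from transformers import BartConfig, BartModel

# Tiny randomly initialized BartModel; it has no labels argument
config = BartConfig(vocab_size=50264, d_model=32,
                    encoder_layers=1, decoder_layers=1,
                    encoder_attention_heads=2, decoder_attention_heads=2,
                    encoder_ffn_dim=64, decoder_ffn_dim=64)
model = BartModel(config)
lm_head = torch.nn.Linear(config.d_model, config.vocab_size, bias=False)

input_ids = torch.randint(0, config.vocab_size, (2, 256))
decoder_input_ids = torch.randint(0, config.vocab_size, (2, 68))
labels = torch.randint(0, config.vocab_size, (2, 68))

hidden = model(input_ids=input_ids,
               decoder_input_ids=decoder_input_ids,
               return_dict=True).last_hidden_state   # (2, 68, d_model)
logits = lm_head(hidden)                             # (2, 68, 50264)

# Flatten to (batch * tgt_len, vocab) for token-level cross-entropy
loss = F.cross_entropy(logits.view(-1, config.vocab_size), labels.view(-1))
```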

About the original question: for me, the following commands gave the correct outputs and shape (I used TF). But it seems we need 'max_length' padding.

model = TFBartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

src_texts = ['My friends are cool but they eat too many carbs. I really want them to be healthy, so I buy them vegetable.']
tgt_texts = ['I buy them vegetable.']
x = tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts, return_tensors='tf',padding='max_length')

out = model(x)

Note: in my case, padding had to be specified as 'max_length'; otherwise, the model call failed.
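To illustrate the difference between the two padding strategies, here is a toy pure-Python sketch (not the real tokenizer; pad_id=1 matches BART's pad token): 'longest' pads each batch only to its longest sequence, while 'max_length' pads every row to a fixed size.

```python
# Toy illustration of tokenizer padding modes (not the real tokenizer)
def pad_batch(seqs, mode="longest", max_length=8, pad_id=1):
    """Pad lists of token ids to the batch's longest row or a fixed length."""
    target = max(len(s) for s in seqs) if mode == "longest" else max_length
    return [s + [pad_id] * (target - len(s)) for s in seqs]

batch = [[0, 5, 6, 2], [0, 7, 2]]
print(pad_batch(batch, "longest"))     # [[0, 5, 6, 2], [0, 7, 2, 1]]
print(pad_batch(batch, "max_length"))  # every row padded to length 8
```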