Mismatch of tensor shapes in CrossEntropyLoss for custom head layer in BART

Hi,

so far I’ve been working with BartForConditionalGeneration. Now I want to use a custom head layer instead.
In a linear layer after the base model's decoder, I want to feed in the output of the base BART model together with some additional numerical data, similar to the code here. Following that example, I came up with this forward function:

def forward(self, input_ids, tokens, **kwargs):
    labels = kwargs.get('labels')
    attn_mask = kwargs.get('attention_mask')
    out = self.model_base(input_ids, attention_mask=attn_mask)
    token_features = tokens.unsqueeze(1)                               # [batch_size, 1]
    concat = torch.cat((out[0][:, 0, :], token_features), dim=-1)      # [batch_size, hidden_dim + token_dim]
    out_lin = self.custom_layer(concat)                                # [batch_size, vocab_size]

    loss_fct = torch.nn.CrossEntropyLoss()
    masked_lm_loss = loss_fct(out_lin.view(-1, self.model.config.vocab_size), labels.view(-1))

where self.custom_layer is the following linear layer (out_lin above is its output):

self.custom_layer = torch.nn.Linear(in_features=self.hidden_dim + self.token_dim, out_features=self.model.config.vocab_size)
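
To make the shapes concrete, here is a minimal standalone reproduction with dummy tensors (batch_size, seq_len, hidden_dim and vocab_size are taken from my setup; token_dim = 1 is my assumption for a single numerical feature):

import torch

batch_size, seq_len, hidden_dim, token_dim, vocab_size = 6, 256, 768, 1, 50267

decoder_out = torch.randn(batch_size, seq_len, hidden_dim)   # stands in for out[0]
tokens = torch.randn(batch_size)                             # my additional numerical data
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

custom_layer = torch.nn.Linear(hidden_dim + token_dim, vocab_size)

token_features = tokens.unsqueeze(1)                                 # [6, 1]
concat = torch.cat((decoder_out[:, 0, :], token_features), dim=-1)   # [6, 769]
out_lin = custom_layer(concat)                                       # [6, 50267]

print(out_lin.view(-1, vocab_size).shape)  # torch.Size([6, 50267])
print(labels.view(-1).shape)               # torch.Size([1536])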

For the loss function I followed the original code of BartForConditionalGeneration:

outputs = self.model(
    input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    encoder_outputs=encoder_outputs,
    decoder_attention_mask=decoder_attention_mask,
    decoder_cached_states=decoder_cached_states,
    use_cache=use_cache,
)
lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
outputs = (lm_logits,) + outputs[1:]  # Add cache, hidden states and attention if they are here
if lm_labels is not None:
    loss_fct = nn.CrossEntropyLoss()
    # TODO(SS): do we need to ignore pad tokens in lm_labels?
    masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), lm_labels.view(-1))
    outputs = (masked_lm_loss,) + outputs

return outputs
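
If I follow the shapes correctly, outputs[0] in this original code still has the full sequence dimension, so the flattened logits and labels line up. A quick sanity check with dummy tensors (shapes assumed from my config):

import torch

lm_logits = torch.randn(6, 256, 50267)         # [batch_size, seq_len, vocab_size]
lm_labels = torch.randint(0, 50267, (6, 256))  # [batch_size, seq_len]

loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(lm_logits.view(-1, 50267), lm_labels.view(-1))  # [1536, 50267] vs [1536] -> works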

The error I obtain is

---------------------------------------------------------------------------
File c:\Users\M\Anaconda\envs\simp_env\lib\site-packages\pytorch_lightning\trainer\call.py:38, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     36         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     37     else:
---> 38         return trainer_fn(*args, **kwargs)
     40 except _TunerExitException:
...
   3024 if size_average is not None or reduce is not None:
   3025     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3026 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

ValueError: Expected input batch_size (6) to match target batch_size (1536).

So far I understand that 6 is the batch size of the output of my custom linear layer, which has shape torch.Size([6, 50267]), where 50267 is self.model.config.vocab_size. My labels have shape torch.Size([6, 256]), which when flattened gives 6 × 256 = 1536.
Since my labels have the same shape as before, and my layer looks to me the same as the one from BartForConditionalGeneration that I used before, I am unsure why I suddenly get this size incompatibility when I did not before.

Furthermore, I am unsure why the first code I referenced uses only hidden_states.last_hidden_state[:, 0, :], i.e., only (batch_size, hidden_size) without the sequence length. Without that slicing, my data has shape torch.Size([6, 256, 768]).
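
One thing I considered (I am not sure this is the intended approach, so treat it as a sketch) is broadcasting the numerical features across the sequence dimension, so that the concatenation keeps seq_len and the flattened shapes match again:

import torch

batch_size, seq_len, hidden_dim, token_dim, vocab_size = 6, 256, 768, 1, 50267
decoder_out = torch.randn(batch_size, seq_len, hidden_dim)
tokens = torch.randn(batch_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
custom_layer = torch.nn.Linear(hidden_dim + token_dim, vocab_size)

# repeat the numerical feature for every decoder position (hypothetical fix)
token_features = tokens.view(batch_size, 1, 1).expand(-1, seq_len, -1)  # [6, 256, 1]
concat = torch.cat((decoder_out, token_features), dim=-1)               # [6, 256, 769]
out_lin = custom_layer(concat)                                          # [6, 256, 50267]

loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(out_lin.view(-1, vocab_size), labels.view(-1))  # 1536 rows on both sides

Would this be a sensible way to combine the two inputs, or does the example use only the first position on purpose?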

I would be thankful for any guidance on how to make the tensors compatible and get the custom layer working.
Did I misunderstand the examples mentioned above?