The size of tensor a (146) must match the size of tensor b (1214) at non-singleton dimension 1

Hello,

I am currently trying to use the Audio Spectrogram Transformer (AST) for emotion recognition on the MELD dataset, but I have hit a bit of a wall with the following error:

11 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

<ipython-input-23-1eec2900cfb9> in forward(self, text_input, audio_input)
      8     def forward(self, text_input, audio_input):
      9         text_output = self.text_model(**text_input).hidden_states[-1][:, 0, :]
---> 10         audio_output = self.audio_model(audio_input).last_hidden_state
     11         concatenated = torch.cat((text_output, audio_output), dim=-1)
     12         logits = self.classifier(concatenated)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

/usr/local/lib/python3.10/dist-packages/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py in forward(self, input_values, head_mask, labels, output_attentions, output_hidden_states, return_dict)
    571         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    572 
--> 573         outputs = self.audio_spectrogram_transformer(
    574             input_values,
    575             head_mask=head_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

/usr/local/lib/python3.10/dist-packages/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py in forward(self, input_values, head_mask, output_attentions, output_hidden_states, return_dict)
    488         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
    489 
--> 490         embedding_output = self.embeddings(input_values)
    491 
    492         encoder_outputs = self.encoder(

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

/usr/local/lib/python3.10/dist-packages/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py in forward(self, input_values)
     85         distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
     86         embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
---> 87         embeddings = embeddings + self.position_embeddings
     88         embeddings = self.dropout(embeddings)
     89 

RuntimeError: The size of tensor a (146) must match the size of tensor b (1214) at non-singleton dimension 1

This is where I got confused, because I have no clue where tensor b with size 1214 came from.

Another part of the traceback shows the error comes from this part of my code:

class MultimodalModel(nn.Module):
    def __init__(self, text_model, audio_model, num_classes):
        super(MultimodalModel, self).__init__()
        self.text_model = text_model
        self.audio_model = audio_model
        self.classifier = nn.Linear(text_model.config.hidden_size + audio_model.config.hidden_size, num_classes)

    def forward(self, text_input, audio_input):
        text_output = self.text_model(**text_input).hidden_states[-1][:, 0, :]
        audio_output = self.audio_model(audio_input).last_hidden_state
        concatenated = torch.cat((text_output, audio_output), dim=-1)
        logits = self.classifier(concatenated)
        return logits
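
For context, this is roughly how the model gets put together; the text model name and class count below are only placeholders for my setup, not the exact checkpoints/values I use:

from transformers import AutoModel

# Placeholder text encoder (output_hidden_states=True so .hidden_states is populated).
text_model = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True)

# ast_model is loaded further down with AutoModelForAudioClassification.
model = MultimodalModel(text_model, ast_model, num_classes=7)  # MELD has 7 emotion labels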

Just to help, my train and test data have the following input shapes:
Train: torch.Size([9887, 128, 128])
Test: torch.Size([1094, 128, 128])
where the first number is the number of samples, the second is the n_mels I set during preprocessing, and the third is the number of target_frames I use to keep the entire dataset consistent.
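
For comparison, here is a quick check of what spectrogram shape the pretrained checkpoint seems to expect (I am assuming num_mel_bins and max_length are the relevant config fields):

from transformers import AutoConfig

# What the checkpoint was configured with (assumption: these fields describe the input shape).
ast_config = AutoConfig.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
print(ast_config.num_mel_bins)  # mel bins per frame the model expects
print(ast_config.max_length)    # number of time frames the position embeddings cover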

For the AST model, I am using the pre-trained checkpoint:

from transformers import AutoModelForAudioClassification

ast_model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
ast_model = AutoModelForAudioClassification.from_pretrained(ast_model_name)
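
For what it's worth, I believe the same error can be reproduced with a single dummy spectrogram shaped like my data (assuming the model expects input_values of shape (batch, time_frames, num_mel_bins)):

import torch

# One fake spectrogram with my preprocessing shape: (batch=1, 128, 128).
dummy = torch.randn(1, 128, 128)
with torch.no_grad():
    ast_model(input_values=dummy)  # raises the same size-mismatch RuntimeError for me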

I truly appreciate any help I can get, and thank you for reading this post.
If more code is needed for a better understanding, I will gladly provide it.
Once again, thank you!