Extracting the decoder of an encoder-decoder model

I am trying to build a video-to-text model using a Hugging Face VisionEncoderDecoderModel. For the encoder I'm using VideoMAE, and because the sequence length for videos is long, I want to use the decoder from the Longformer Encoder-Decoder (LED) model. Since LED is itself an encoder-decoder model, I extract its decoder to construct my model like this:

enc = "MCG-NJU/videomae-base"
dec = "allenai/led-base-16384"
encoder = AutoModel.from_pretrained(enc)
enc_dec = AutoModel.from_pretrained(dec)
model = VisionEncoderDecoderModel(encoder=encoder, decoder=enc_dec.get_decoder())
...
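
For what it's worth, inspecting the extracted decoder suggests it is the bare decoder stack rather than a model with a language-modeling head (a quick check with the checkpoints above):

print(type(enc_dec).__name__)                # LEDModel
print(type(enc_dec.get_decoder()).__name__)  # LEDDecoder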

When I try to run inference with the model like this:

import torch

# vids: preprocessed video tensor for pixel_values
# texts: tokenized decoder input ids
with torch.no_grad():
    output = model(pixel_values=vids, decoder_input_ids=texts)
    print(output.last_hidden_state.shape)

I get this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[40], line 7
      1 with torch.no_grad():
----> 2     output = model(pixel_values=vids, decoder_input_ids=texts)
      3     print(output.last_hidden_state.shape)

File ~/miniconda3/envs/vit/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/vit/lib/python3.10/site-packages/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py:638, in VisionEncoderDecoderModel.forward(self, pixel_values, decoder_input_ids, decoder_attention_mask, encoder_outputs, past_key_values, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)
    633     else:
    634         return decoder_outputs + encoder_outputs
    636 return Seq2SeqLMOutput(
    637     loss=loss,
--> 638     logits=decoder_outputs.logits,
    639     past_key_values=decoder_outputs.past_key_values,
    640     decoder_hidden_states=decoder_outputs.hidden_states,
    641     decoder_attentions=decoder_outputs.attentions,
    642     cross_attentions=decoder_outputs.cross_attentions,
    643     encoder_last_hidden_state=encoder_outputs.last_hidden_state,
    644     encoder_hidden_states=encoder_outputs.hidden_states,
    645     encoder_attentions=encoder_outputs.attentions,
    646 )

AttributeError: 'BaseModelOutputWithPastAndCrossAttentions' object has no attribute 'logits'

I suspect this is because I am not using the whole encoder-decoder model: the extracted decoder is only the bare transformer stack, so it returns a BaseModelOutputWithPastAndCrossAttentions containing hidden states but no logits, while VisionEncoderDecoderModel expects its decoder to produce logits (the language-modeling head lives in LEDForConditionalGeneration, not in the decoder itself). Is there a way to get the decoder that I'm extracting to return an output that has logits?
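
One idea I had is to wrap the extracted decoder and add the missing head myself, along these lines (a rough, untested sketch; DecoderWithLMHead is a name I made up, and I'm not sure VisionEncoderDecoderModel accepts a plain nn.Module as its decoder):

import torch.nn as nn
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

class DecoderWithLMHead(nn.Module):
    # Hypothetical wrapper: bolts a language-modeling head onto the bare
    # LED decoder so its forward output has a logits field.
    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        self.config = decoder.config  # VisionEncoderDecoderModel reads this
        self.lm_head = nn.Linear(self.config.d_model, self.config.vocab_size, bias=False)
        # LED normally ties the head to the token embeddings, so perhaps:
        # self.lm_head.weight = self.decoder.embed_tokens.weight

    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids=None, attention_mask=None,
                encoder_hidden_states=None, encoder_attention_mask=None, **kwargs):
        outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            **kwargs,
        )
        # Project decoder hidden states to vocabulary logits
        logits = self.lm_head(outputs.last_hidden_state)
        return CausalLMOutputWithCrossAttentions(
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

model = VisionEncoderDecoderModel(encoder=encoder,
                                  decoder=DecoderWithLMHead(enc_dec.get_decoder()))

Even if this runs, the new lm_head would be randomly initialized, so I assume I would still have to copy the tied embedding weights from the LED checkpoint or train the head from scratch. Is something like this reasonable, or is there a cleaner way?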