EncoderDecoderModel loaded from pre-trained checkpoints fails when calling generate

I am trying to instantiate and fine-tune an EncoderDecoderModel from checkpoints of two pre-trained language models (encoder: BigBirdForMaskedLM and decoder: BigBirdForCausalLM) as follows:

from transformers import EncoderDecoderModel

encdec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "../models/pretrained/enc/checkpoint-540000/",
        "../models/pretrained/dec/checkpoint-1820000/"
)

The forward pass works fine, and the model trains without any issues:

seq2seq_output = encdec_model(
        input_ids=input_ids, 
        decoder_input_ids=decoder_input_ids, 
        labels=labels
)
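
For context, the tensors fed to the forward pass are ordinary batches of token ids from my dataset. A rough sketch of what they look like (the shapes and values below are purely illustrative, not my actual data) is:

import torch

# Illustrative dummy batch: 4 examples, 128 encoder tokens and 16 decoder tokens each
input_ids = torch.randint(0, encdec_model.config.encoder.vocab_size, (4, 128))
decoder_input_ids = torch.randint(0, encdec_model.config.decoder.vocab_size, (4, 16))
labels = decoder_input_ids.clone()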

However, calling generate as follows fails with a RuntimeError:

generated = encdec_model.generate(
        input_ids, 
        decoder_start_token_id=2,
        num_beams=4, max_length=10
)

Here is the full stack trace:

RuntimeError                              Traceback (most recent call last)
<ipython-input-14-276b11d42bb0> in <module>
      5         input_ids,
      6         decoder_start_token_id=2,
----> 7         num_beams=4, max_length=10
      8 )

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     26         def decorate_context(*args, **kwargs):
     27             with self.__class__():
---> 28                 return func(*args, **kwargs)
     29         return cast(F, decorate_context)
     30 

~/anaconda3/envs/routing/lib/python3.6/site-packages/transformers/generation_utils.py in generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, **model_kwargs)
   1061                 return_dict_in_generate=return_dict_in_generate,
   1062                 synced_gpus=synced_gpus,
-> 1063                 **model_kwargs,
   1064             )
   1065 

~/anaconda3/envs/routing/lib/python3.6/site-packages/transformers/generation_utils.py in beam_search(self, input_ids, beam_scorer, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
   1792                 return_dict=True,
   1793                 output_attentions=output_attentions,
-> 1794                 output_hidden_states=output_hidden_states,
   1795             )
   1796 

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/routing/lib/python3.6/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py in forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)
    448             past_key_values=past_key_values,
    449             return_dict=return_dict,
--> 450             **kwargs_decoder,
    451         )
    452 

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/routing/lib/python3.6/site-packages/transformers/models/big_bird/modeling_big_bird.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   2551             output_attentions=output_attentions,
   2552             output_hidden_states=output_hidden_states,
-> 2553             return_dict=return_dict,
   2554         )
   2555 

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/routing/lib/python3.6/site-packages/transformers/models/big_bird/modeling_big_bird.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
   2131             token_type_ids=token_type_ids,
   2132             inputs_embeds=inputs_embeds,
-> 2133             past_key_values_length=past_key_values_length,
   2134         )
   2135 

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/routing/lib/python3.6/site-packages/transformers/models/big_bird/modeling_big_bird.py in forward(self, input_ids, token_type_ids, position_ids, inputs_embeds, past_key_values_length)
    305 
    306         position_embeddings = self.position_embeddings(position_ids)
--> 307         embeddings += position_embeddings
    308 
    309         embeddings = self.dropout(embeddings)

RuntimeError: output with shape [4, 1, 768] doesn't match the broadcast shape [4, 0, 768]

I should add that when I build a fresh EncoderDecoderModel (randomly initialized) and call generate, there is no error. The failure only occurs when the model is loaded from pre-trained checkpoints with from_encoder_decoder_pretrained.
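
By "fresh" I mean something along the lines of the sketch below, where both halves are randomly initialized from BigBird configs rather than loaded from disk (the configuration values here are illustrative):

from transformers import BigBirdConfig, EncoderDecoderConfig, EncoderDecoderModel
import torch

# Randomly initialized encoder/decoder pair; full attention keeps the sketch simple
enc_config = BigBirdConfig(attention_type="original_full")
dec_config = BigBirdConfig(attention_type="original_full", is_decoder=True, add_cross_attention=True)
config = EncoderDecoderConfig.from_encoder_decoder_configs(enc_config, dec_config)
fresh_model = EncoderDecoderModel(config=config)

# The same generate call as above runs without the shape error on this fresh model
input_ids = torch.randint(0, enc_config.vocab_size, (4, 128))
generated = fresh_model.generate(
        input_ids,
        decoder_start_token_id=2,
        num_beams=4, max_length=10
)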

Any help with this would be appreciated. Thank you.