The correct way to load an EncoderDecoderModel from pre-trained encoder and decoder checkpoints

I am trying to instantiate an EncoderDecoderModel from checkpoints of two pre-trained language models (encoder: BigBirdForMaskedLM and decoder: BigBirdForCausalLM) as follows:

from transformers import EncoderDecoderModel

encdec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "../models/pretrained/enc/checkpoint-540000/",
    "../models/pretrained/dec/checkpoint-1820000/",
)

# set the decoder config to causal LM with cross-attention
config_encoder = encdec_model.config.encoder
config_decoder = encdec_model.config.decoder
config_decoder.is_decoder = True
config_decoder.add_cross_attention = True

Somehow the resulting EncoderDecoderModel is not behaving as expected: it trains without any errors using the Trainer, but it crashes when calling .generate on input sequences longer than the decoder's max_position_embeddings (512), even though the encoder has a much longer input span (10240) and should be able to handle that input length. The last frames of the stack trace when calling generate are as follows:

~/anaconda3/envs/routing/lib/python3.6/site-packages/transformers/models/big_bird/modeling_big_bird.py in forward(self, input_ids, token_type_ids, position_ids, inputs_embeds, past_key_values_length)
    305 
    306         position_embeddings = self.position_embeddings(position_ids)
--> 307         embeddings += position_embeddings
    308 
    309         embeddings = self.dropout(embeddings)

RuntimeError: output with shape [2, 1, 768] doesn't match the broadcast shape [2, 0, 768]
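For context, the generate call that triggers this looks roughly like the following (dummy inputs here; in the real script the inputs are tokenized documents of comparable length, and the decoder start token is an assumption based on the bos_token_id in the config):

import torch

batch_size, src_len = 2, 2048  # src_len well beyond the decoder's 512 positions
input_ids = torch.randint(5, 32000, (batch_size, src_len))
attention_mask = torch.ones_like(input_ids)

# works for inputs shorter than ~512 tokens, crashes with the RuntimeError above otherwise
generated = encdec_model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_start_token_id=1,  # assuming bos_token_id from the config
    max_length=128,
)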

I think this error might be related to mismatched tensor sizes caused by incorrect loading of the model weights.

I also tried loading the config files from the checkpoints into an EncoderDecoderConfig and then instantiating a fresh EncoderDecoderModel from it; that raised the same error, even without loading the pretrained checkpoints.
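Roughly what I did there (paths as above; this is a sketch, not the exact script):

from transformers import BigBirdConfig, EncoderDecoderConfig, EncoderDecoderModel

# load the two checkpoint configs and combine them into an EncoderDecoderConfig
config_encoder = BigBirdConfig.from_pretrained("../models/pretrained/enc/checkpoint-540000/")
config_decoder = BigBirdConfig.from_pretrained("../models/pretrained/dec/checkpoint-1820000/")
config_decoder.is_decoder = True
config_decoder.add_cross_attention = True

config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

# fresh (randomly initialised) model built only from the checkpoint configs --
# .generate on long inputs raises the same RuntimeError as above
encdec_model = EncoderDecoderModel(config=config)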

What seemed to fix the error was to instantiate a fresh EncoderDecoderModel (without using the checkpoint config files).
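For reference, that workaround looks roughly like this (the config values mirror the JSON configs below; it produces a randomly initialised model, so it is not a real solution):

from transformers import BigBirdConfig, EncoderDecoderConfig, EncoderDecoderModel

# build the configs in code instead of loading the checkpoint config files
config_encoder = BigBirdConfig(
    vocab_size=32000,
    max_position_embeddings=10240,
    attention_type="block_sparse",
    block_size=64,
    num_random_blocks=8,
)
config_decoder = BigBirdConfig(
    vocab_size=32000,
    max_position_embeddings=512,
    attention_type="original_full",
    is_decoder=True,
    add_cross_attention=True,
)

config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
encdec_model = EncoderDecoderModel(config=config)  # seems to avoid the error, but with random weights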

I would appreciate any ideas on why I am getting this error when calling .generate (even though the model trains fine), and on the correct way to load an EncoderDecoderModel from pretrained checkpoints.

I am also including the pre-trained encoder and decoder configs here for reference, since the problem seems to be related to (or caused by) them:

Encoder config:

{
  "architectures": [
    "BigBirdForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "attention_type": "block_sparse",
  "block_size": 64,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu_new",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 10240,
  "model_type": "big_bird",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_random_blocks": 8,
  "pad_token_id": 0,
  "rescale_embeddings": false,
  "sep_token_id": 66,
  "transformers_version": "4.8.2",
  "type_vocab_size": 2,
  "use_bias": true,
  "use_cache": true,
  "vocab_size": 32000
}

Decoder config:

{
  "architectures": [
    "BigBirdForCausalLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "attention_type": "original_full",
  "block_size": 64,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu_new",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": true,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "big_bird",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_random_blocks": 3,
  "pad_token_id": 0,
  "rescale_embeddings": false,
  "sep_token_id": 66,
  "transformers_version": "4.8.2",
  "type_vocab_size": 2,
  "use_bias": true,
  "use_cache": true,
  "vocab_size": 32000
}

Thank you for your help!