EncoderDecoderModel loaded from pre-trained checkpoints fails when calling generate

I am trying to instantiate and fine-tune an EncoderDecoderModel from checkpoints of two pre-trained language models (encoder: BigBirdForMaskedLM and decoder: BigBirdForCausalLM) as follows:

encdec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "../models/pretrained/enc/checkpoint-540000/", 
        "../models/pretrained/dec/checkpoint-1820000/"
)

The forward pass works fine, and the model trains without errors:

seq2seq_output = encdec_model(
        input_ids=input_ids, 
        decoder_input_ids=decoder_input_ids, 
        labels=labels
)

But when calling generate as follows it fails with a runtime error:

generated = encdec_model.generate(
        input_ids, 
        decoder_start_token_id=2,
        num_beams=4, max_length=10
)

Here is the error stack:

RuntimeError                              Traceback (most recent call last)
<ipython-input-14-276b11d42bb0> in <module>
      5         input_ids,
      6         decoder_start_token_id=2,
----> 7         num_beams=4, max_length=10
      8 )

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     26         def decorate_context(*args, **kwargs):
     27             with self.__class__():
---> 28                 return func(*args, **kwargs)
     29         return cast(F, decorate_context)
     30 

~/anaconda3/envs/routing/lib/python3.6/site-packages/transformers/generation_utils.py in generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, **model_kwargs)
   1061                 return_dict_in_generate=return_dict_in_generate,
   1062                 synced_gpus=synced_gpus,
-> 1063                 **model_kwargs,
   1064             )
   1065 

~/anaconda3/envs/routing/lib/python3.6/site-packages/transformers/generation_utils.py in beam_search(self, input_ids, beam_scorer, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
   1792                 return_dict=True,
   1793                 output_attentions=output_attentions,
-> 1794                 output_hidden_states=output_hidden_states,
   1795             )
   1796 

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/routing/lib/python3.6/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py in forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)
    448             past_key_values=past_key_values,
    449             return_dict=return_dict,
--> 450             **kwargs_decoder,
    451         )
    452 

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/routing/lib/python3.6/site-packages/transformers/models/big_bird/modeling_big_bird.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   2551             output_attentions=output_attentions,
   2552             output_hidden_states=output_hidden_states,
-> 2553             return_dict=return_dict,
   2554         )
   2555 

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/routing/lib/python3.6/site-packages/transformers/models/big_bird/modeling_big_bird.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
   2131             token_type_ids=token_type_ids,
   2132             inputs_embeds=inputs_embeds,
-> 2133             past_key_values_length=past_key_values_length,
   2134         )
   2135 

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/routing/lib/python3.6/site-packages/transformers/models/big_bird/modeling_big_bird.py in forward(self, input_ids, token_type_ids, position_ids, inputs_embeds, past_key_values_length)
    305 
    306         position_embeddings = self.position_embeddings(position_ids)
--> 307         embeddings += position_embeddings
    308 
    309         embeddings = self.dropout(embeddings)

RuntimeError: output with shape [4, 1, 768] doesn't match the broadcast shape [4, 0, 768]

I should add that calling generate on a freshly initialized EncoderDecoderModel gives no error. The error only happens when the model is loaded from pre-trained checkpoints with from_encoder_decoder_pretrained.

Any help with this would be appreciated. Thank you.

@aliosk Have you found any solution to this?

@raygx hey! Can you provide a minimal code that I can run to reproduce the error?

@RaushanTurganbay Thanks for the response.
I think this has to do with some version mismatch. Hopefully you can help me find a solution.

This code runs on Google Colab with a T4 GPU.

Library Installation & Imports

# Installing all the required libraries
!pip install datasets --q
# !pip install -U transformers --q
!pip install rouge_score --q

#Using the `Trainer` with `PyTorch` requires `accelerate>=0.21.0`
!pip install -U transformers[torch] --q
!pip install accelerate -U --q

import pandas as pd
from datasets import Dataset


from datasets import load_dataset
from transformers import BertTokenizer, BertLMHeadModel
from transformers import DataCollatorWithPadding, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback
from rouge_score import rouge_scorer

Load the BertLMHeadModel and tokenizer

model_name = "NepBERTa/NepBERTa"

model = BertLMHeadModel.from_pretrained(model_name, from_tf=True, is_decoder=True)

tokenizer = BertTokenizer.from_pretrained(model_name)

Preprocess the data for fine-tuning

def preprocess_function(examples):
    inputs = examples["text"]
    targets = examples["target"]

    # Tokenize the inputs
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length", return_tensors="pt")

    # Tokenize the targets (text_target replaces the deprecated as_target_tokenizer() context manager)
    labels = tokenizer(text_target=targets, max_length=512, truncation=True, padding="max_length", return_tensors="pt")

    model_inputs["labels"] = labels.input_ids

    return model_inputs

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding="max_length", return_tensors="pt")

tokenized_train = train_dataset.map(preprocess_function, batched=True)
tokenized_train
tokenized_val = val_dataset.map(preprocess_function, batched=True)
tokenized_val
tokenized_test = test_dataset.map(preprocess_function, batched=True, batch_size=10,  remove_columns=test_dataset.column_names)
tokenized_test

Fine-tune the BertLMHeadModel for text summarization

training_args = Seq2SeqTrainingArguments(
    output_dir="./NepBERTa-finetuned",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=15,
    weight_decay=0.01,
    save_total_limit=5,
    push_to_hub=False,
    load_best_model_at_end=True
)

#Stopping training if validation loss doesn't improve for 3 epochs
early_stopping = EarlyStoppingCallback(early_stopping_patience = 3)


trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
    callbacks=[early_stopping]      #Add the early stopping callback
)

# Get the first batch of the training dataloader
for batch in trainer.get_train_dataloader():
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]

    # Pass the input through the model to get the output
    #outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    # Print the input and output sizes
    print("Input size:", input_ids.size())
    print("Attention Mask size:", attention_mask.size())
    print("Labels size:", labels.size())
    #print("Output size:", outputs.size())

    break  # Break after processing the first batch

# Continue with trainer.train() to start the training process
trainer.train()

Block Where it Errors

from rouge_score import rouge_scorer

#Evaluate the fine-tuned model using ROUGE
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

references = test_dataset["target"]
predictions = []

for i in range(len(test_dataset["target"])):
  inputs = tokenizer(test_dataset["target"][i], return_tensors="pt", max_length=512, truncation=True, padding="max_length").input_ids
  # print(inputs)
  inputs.to('cuda:0') #if running on GPU
  output_ids = model.generate(inputs, max_new_tokens=10)#, num_beams=4)
  prediction = tokenizer.decode(output_ids[0], skip_special_tokens=True)
  predictions.append(prediction)

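# RougeScorer has no score_pairs(); score each (reference, prediction) pair and average the F-measures
pair_scores = [scorer.score(ref, pred) for ref, pred in zip(references, predictions)]
for key, label in [('rouge1', 'ROUGE-1'), ('rouge2', 'ROUGE-2'), ('rougeL', 'ROUGE-L')]:
    print(f"{label}:", sum(s[key].fmeasure for s in pair_scores) / len(pair_scores))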

The Error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-23-4b5f4f858d4f> in <cell line: 10>()
     12   # print(inputs)
     13   inputs.to('cuda:0') #if running on GPU
---> 14   output_ids = model.generate(inputs, max_new_tokens=10)#, num_beams=4)
     15   prediction = tokenizer.decode(output_ids[0], skip_special_tokens=True)
     16   predictions.append(prediction)

/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116 
    117     return decorate_context

/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1756 
   1757             # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 1758             result = self._sample(
   1759                 input_ids,
   1760                 logits_processor=prepared_logits_processor,

/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py in _sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2395 
   2396             # forward pass to get next token
-> 2397             outputs = self(
   2398                 **model_inputs,
   2399                 return_dict=True,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/transformers/models/bert/modeling_bert.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
   1350             use_cache = False
   1351 
-> 1352         outputs = self.bert(
   1353             input_ids,
   1354             attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/transformers/models/bert/modeling_bert.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
   1071                 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
   1072 
-> 1073         embedding_output = self.embeddings(
   1074             input_ids=input_ids,
   1075             position_ids=position_ids,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/transformers/models/bert/modeling_bert.py in forward(self, input_ids, token_type_ids, position_ids, inputs_embeds, past_key_values_length)
    214         if self.position_embedding_type == "absolute":
    215             position_embeddings = self.position_embeddings(position_ids)
--> 216             embeddings += position_embeddings
    217         embeddings = self.LayerNorm(embeddings)
    218         embeddings = self.dropout(embeddings)

RuntimeError: output with shape [1, 1, 768] doesn't match the broadcast shape [1, 0, 768]

@raygx Thanks for sharing a reproducer. I found the issue.

Bert has a maximum of 512 position embeddings, and this error is raised whenever generation goes beyond 512 tokens. Because your prompts are padded to 512 tokens with `padding="max_length"`, feeding the first generated token back into the model already requires position 512, which doesn’t exist. For inference, I recommend not using `padding="max_length"` and instead letting each batch be padded to the longest sequence within that batch. If the dataset contains long texts, this alone won’t help much, and you’d have to truncate the inputs. That’s because the Bert model you’re using is not an encoder-decoder; it’s a plain decoder-only Bert, so during inference the total length (input + generated tokens) has to stay within 512 tokens.
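To make the shape in the error concrete, here is a plain-Python sketch of how BERT-style absolute position ids are sliced out of a fixed table of `max_position_embeddings` entries, offset by the number of tokens already cached (`past_key_values_length`). It is a simplified model of that slicing, not the actual transformers code, and the helper name is hypothetical.

```python
MAX_POSITION_EMBEDDINGS = 512  # BERT's default max_position_embeddings


def position_ids_for_step(seq_length, past_key_values_length,
                          max_positions=MAX_POSITION_EMBEDDINGS):
    """Hypothetical helper: slice position ids out of a fixed [0, max_positions) table."""
    table = list(range(max_positions))
    return table[past_key_values_length:past_key_values_length + seq_length]


# Step 1: the whole prompt, padded to 512 tokens, nothing cached yet
first_step = position_ids_for_step(seq_length=512, past_key_values_length=0)

# Step 2: the first generated token is fed back in, with 512 tokens already cached
second_step = position_ids_for_step(seq_length=1, past_key_values_length=512)

print(len(first_step), len(second_step))  # 512 0
```

The empty slice in step 2 corresponds to the `0` in `[1, 0, 768]`: there is one new token embedding but zero position embeddings to add to it. A prompt shorter than 512 tokens leaves some positions free for generated tokens.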

It didn’t fail during training because training doesn’t perform generation; it does a single forward pass over the inputs (source+target).
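Following the advice above, one conservative way to size inference inputs is to reserve room for the generated tokens up front and truncate each prompt to `max_position_embeddings - max_new_tokens` tokens. A minimal sketch; the helper name is hypothetical, and 512 and 10 are the values used in this thread.

```python
def prompt_token_budget(max_position_embeddings, max_new_tokens):
    """Hypothetical helper: longest prompt that still leaves room for
    max_new_tokens within the model's absolute position embeddings."""
    budget = max_position_embeddings - max_new_tokens
    if budget <= 0:
        raise ValueError(
            f"max_new_tokens={max_new_tokens} leaves no room for a prompt "
            f"with max_position_embeddings={max_position_embeddings}"
        )
    return budget


print(prompt_token_budget(512, 10))  # 502
```

In the evaluation loop from the question, this budget would replace `padding="max_length"` in the tokenizer call, e.g. `tokenizer(text, return_tensors="pt", truncation=True, max_length=prompt_token_budget(model.config.max_position_embeddings, 10))`. Also note that `inputs.to('cuda:0')` returns a new tensor rather than moving `inputs` in place, so it needs to be reassigned, e.g. `inputs = inputs.to(model.device)`.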

Here are some resources to improve the training script:

  1. Encoder-decoder model from two BERTs training → Encoder Decoder Models
  2. Train decoder-only model GPT → transformers/examples/pytorch/language-modeling/run_clm.py at main · huggingface/transformers · GitHub
  3. Train BART encoder-decoder model → transformers/examples/pytorch/summarization/run_summarization.py at main · huggingface/transformers · GitHub

@RaushanTurganbay Thank you very much for taking the time to look into this. It worked. Thanks!
