[Feature Request] Gradient Checkpointing for EncoderDecoderModel

Currently, Seq2SeqTrainingArguments exposes a gradient_checkpointing flag and Seq2SeqTrainer accepts that config, but when the trainer is used with an EncoderDecoderModel, enabling gradient checkpointing raises an error.

Q: Other than EncoderDecoderModel, which models that Seq2SeqTrainer accepts support gradient checkpointing?

(I'm asking because I'm curious about other ways to use gradient checkpointing with Seq2SeqTrainingArguments.)
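One way to check any given model class is its supports_gradient_checkpointing class attribute (no weights need to be downloaded); the classes below are just illustrative picks:

from transformers import (
    BartForConditionalGeneration,
    T5ForConditionalGeneration,
    EncoderDecoderModel,
)

# supports_gradient_checkpointing is a class-level flag on PreTrainedModel subclasses
for cls in (BartForConditionalGeneration, T5ForConditionalGeneration, EncoderDecoderModel):
    print(cls.__name__, cls.supports_gradient_checkpointing)

# On my install this prints True for the Bart and T5 classes and False for EncoderDecoderModel.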

Request: can EncoderDecoderModel be made to support gradient checkpointing?

[Code]:

import torch

from datasets import load_dataset
from transformers import EncoderDecoderModel
from transformers import AutoTokenizer
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# tokenizer matching the checkpoints used for the encoder and decoder
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-uncased")
batch_size = 4  # small batch size, just for this demo

multibert = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-multilingual-uncased", "bert-base-multilingual-uncased"
)

# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    logging_steps=2,  # set to 1000 for full training
    save_steps=16,    # set to 500 for full training
    eval_steps=4,     # set to 8000 for full training
    warmup_steps=1,   # set to 2000 for full training
    max_steps=16,     # delete for full training
    # overwrite_output_dir=True,
    save_total_limit=1,
    #fp16=True, 
    gradient_checkpointing=True
)


# instantiate trainer
# (train_data is assumed to be a tokenized `datasets.Dataset` prepared beforehand with
#  load_dataset; the eval split reuses it here only to keep the snippet short)
trainer = Seq2SeqTrainer(
    model=multibert,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_data.with_format("torch"),
    eval_dataset=train_data.with_format("torch"),
)

trainer.train()

[out]:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_28/1791463006.py in <module>
     29 )
     30 
---> 31 trainer.train()

/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1411             resume_from_checkpoint=resume_from_checkpoint,
   1412             trial=trial,
-> 1413             ignore_keys_for_eval=ignore_keys_for_eval,
   1414         )
   1415 

/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1492         # Activate gradient checkpointing if needed
   1493         if args.gradient_checkpointing:
-> 1494             self.model.gradient_checkpointing_enable()
   1495 
   1496         model = self._wrap_model(self.model_wrapped)

/opt/conda/lib/python3.7/site-packages/transformers/modeling_utils.py in gradient_checkpointing_enable(self)
   1515         """
   1516         if not self.supports_gradient_checkpointing:
-> 1517             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
   1518         self.apply(partial(self._set_gradient_checkpointing, value=True))
   1519 

ValueError: EncoderDecoderModel does not support gradient checkpointing.
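In the meantime, a possible workaround (just a sketch, not thoroughly tested) seems to be enabling checkpointing directly on the wrapped BERT encoder and decoder, which do support it, and keeping gradient_checkpointing=False in the Seq2SeqTrainingArguments so the Trainer never calls the unsupported top-level method:

# enable checkpointing on the sub-models directly, bypassing the EncoderDecoderModel-level check
multibert.encoder.gradient_checkpointing_enable()
multibert.decoder.gradient_checkpointing_enable()
multibert.decoder.config.use_cache = False  # cached decoder states are incompatible with checkpointing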

Hi,

I opened a PR here to add support for it: [EncoderDecoderModel] Add support for gradient checkpointing by NielsRogge · Pull Request #19990 · huggingface/transformers · GitHub


Thanks, the PR looks great! I gave it a thumbs up there.

Let me know if a review comment or some local testing is needed.
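For a quick local test, something along these lines could confirm the flag reaches both sub-models once the PR is in (the attribute paths are my assumption for a BERT/BERT pair and may differ between transformers versions):

from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-multilingual-uncased", "bert-base-multilingual-uncased"
)
model.gradient_checkpointing_enable()

# BertEncoder keeps a gradient_checkpointing attribute; check it on both sides
print(model.encoder.encoder.gradient_checkpointing)       # encoder: BertModel -> BertEncoder
print(model.decoder.bert.encoder.gradient_checkpointing)  # decoder: BertLMHeadModel -> BertModel -> BertEncoder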