Currently, Seq2SeqTrainingArguments supports gradient_checkpointing and Seq2SeqTrainer accepts that argument, but when they are used with an EncoderDecoderModel, gradient checkpointing is not allowed.
Q: Other than EncoderDecoderModel, which models that Seq2SeqTrainer accepts do support gradient checkpointing? (I'm curious how else gradient checkpointing can be used with Seq2SeqTrainingArguments, hence the question.)
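To illustrate what I mean, this is roughly how I have been checking whether a given model class supports it. A minimal sketch: supports_gradient_checkpointing is the flag the model checks before enabling checkpointing (it appears in the traceback below), and facebook/bart-base / t5-small are just example checkpoints I picked for comparison.

from transformers import AutoModelForSeq2SeqLM, EncoderDecoderModel

# Example checkpoints, chosen only for illustration.
bart = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
t5 = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
enc_dec = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-multilingual-uncased", "bert-base-multilingual-uncased"
)

# `supports_gradient_checkpointing` is the attribute checked before
# `gradient_checkpointing_enable()` is called (see the traceback below).
for name, model in [("bart", bart), ("t5", t5), ("encoder-decoder", enc_dec)]:
    print(name, model.supports_gradient_checkpointing)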
Request: Can EncoderDecoderModel support gradient checkpointing?
[Code]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, EncoderDecoderModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-uncased")
multibert = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-multilingual-uncased", "bert-base-multilingual-uncased"
)

batch_size = 4  # small placeholder value for this demo
# train_data is a tokenized seq2seq dataset built with `tokenizer`
# (preprocessing omitted here for brevity).

# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    logging_steps=2,   # set to 1000 for full training
    save_steps=16,     # set to 500 for full training
    eval_steps=4,      # set to 8000 for full training
    warmup_steps=1,    # set to 2000 for full training
    max_steps=16,      # delete for full training
    # overwrite_output_dir=True,
    save_total_limit=1,
    # fp16=True,
    gradient_checkpointing=True,
)

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=multibert,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_data.with_format("torch"),
    eval_dataset=train_data.with_format("torch"),
)
trainer.train()
[out]:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_28/1791463006.py in <module>
29 )
30
---> 31 trainer.train()
/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1411 resume_from_checkpoint=resume_from_checkpoint,
1412 trial=trial,
-> 1413 ignore_keys_for_eval=ignore_keys_for_eval,
1414 )
1415
/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
1492 # Activate gradient checkpointing if needed
1493 if args.gradient_checkpointing:
-> 1494 self.model.gradient_checkpointing_enable()
1495
1496 model = self._wrap_model(self.model_wrapped)
/opt/conda/lib/python3.7/site-packages/transformers/modeling_utils.py in gradient_checkpointing_enable(self)
1515 """
1516 if not self.supports_gradient_checkpointing:
-> 1517 raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
1518 self.apply(partial(self._set_gradient_checkpointing, value=True))
1519
ValueError: EncoderDecoderModel does not support gradient checkpointing.
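For completeness, a workaround I have been considering (this is an assumption on my part, not something I found documented for EncoderDecoderModel): keep gradient_checkpointing off in Seq2SeqTrainingArguments and enable checkpointing directly on the BERT encoder and decoder submodules, which do support it.

from transformers import EncoderDecoderModel

multibert = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-multilingual-uncased", "bert-base-multilingual-uncased"
)

# Assumption: enable checkpointing on the BERT submodules instead of the
# EncoderDecoderModel wrapper, since the underlying BERT models support it.
multibert.encoder.gradient_checkpointing_enable()
multibert.decoder.gradient_checkpointing_enable()

# Leave gradient_checkpointing at its default (False) in Seq2SeqTrainingArguments
# so the Trainer does not call gradient_checkpointing_enable() on the wrapper
# and raise the ValueError above.

I have not verified that this actually reduces memory during training, so treat it as a sketch rather than a solution.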