[Feature Request] Gradient Checkpointing for EncoderDecoderModel

Currently, Seq2SeqTrainingArguments exposes a gradient_checkpointing option and Seq2SeqTrainer accepts it, but when the trainer is used with an EncoderDecoderModel, gradient checkpointing cannot be enabled.

Q: Other than EncoderDecoderModel, which models that do support gradient checkpointing can Seq2SeqTrainer be used with?

(I’m curious how else gradient checkpointing can be used with Seq2SeqTrainingArguments, hence the question above; a quick way to check is sketched below.)

Request: Could gradient checkpointing support be added to EncoderDecoderModel?
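
For reference, the support flag can be read straight off the model class, so checking a few candidates is cheap. A minimal sketch (the listed classes are just examples; supports_gradient_checkpointing is the flag that gradient_checkpointing_enable() checks, as seen in the traceback below):

from transformers import (
    BartForConditionalGeneration,
    T5ForConditionalGeneration,
    EncoderDecoderModel,
)

# supports_gradient_checkpointing is a class attribute of PreTrainedModel,
# so no weights need to be downloaded for this check
for cls in (BartForConditionalGeneration, T5ForConditionalGeneration, EncoderDecoderModel):
    print(cls.__name__, cls.supports_gradient_checkpointing)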

[Code]:

import torch

from datasets import load_dataset
from transformers import EncoderDecoderModel
from transformers import AutoTokenizer
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

multibert = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-multilingual-uncased", "bert-base-multilingual-uncased"
)
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-uncased")

batch_size = 4  # placeholder value; adjust to the available GPU memory

# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    logging_steps=2,  # set to 1000 for full training
    save_steps=16,    # set to 500 for full training
    eval_steps=4,     # set to 8000 for full training
    warmup_steps=1,   # set to 2000 for full training
    max_steps=16,     # delete for full training
    # overwrite_output_dir=True,
    save_total_limit=1,
    #fp16=True, 
    gradient_checkpointing=True
)


# train_data: a tokenized seq2seq dataset (loading/tokenization with load_dataset is omitted here)
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=multibert,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_data.with_format("torch"),
    eval_dataset=train_data.with_format("torch"),
)

trainer.train()

[out]:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_28/1791463006.py in <module>
     29 )
     30 
---> 31 trainer.train()

/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1411             resume_from_checkpoint=resume_from_checkpoint,
   1412             trial=trial,
-> 1413             ignore_keys_for_eval=ignore_keys_for_eval,
   1414         )
   1415 

/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1492         # Activate gradient checkpointing if needed
   1493         if args.gradient_checkpointing:
-> 1494             self.model.gradient_checkpointing_enable()
   1495 
   1496         model = self._wrap_model(self.model_wrapped)

/opt/conda/lib/python3.7/site-packages/transformers/modeling_utils.py in gradient_checkpointing_enable(self)
   1515         """
   1516         if not self.supports_gradient_checkpointing:
-> 1517             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
   1518         self.apply(partial(self._set_gradient_checkpointing, value=True))
   1519 

ValueError: EncoderDecoderModel does not support gradient checkpointing.

Hi,

I opened a PR here to add support for it: [EncoderDecoderModel] Add support for gradient checkpointing by NielsRogge · Pull Request #19990 · huggingface/transformers · GitHub
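
Until it is merged, one possible workaround (just a sketch, assuming both the encoder and the decoder are BERT checkpoints, which do support checkpointing) is to enable it on the two submodules directly and keep gradient_checkpointing=False in Seq2SeqTrainingArguments so the Trainer does not call it on the wrapper:

multibert = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-multilingual-uncased", "bert-base-multilingual-uncased"
)

# enable activation checkpointing on the underlying BERT models instead of the wrapper
multibert.encoder.gradient_checkpointing_enable()
multibert.decoder.gradient_checkpointing_enable()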


Thanks, the PR looks great! I gave it my thumbs up there.

Let me know if a comment or some local test is needed.

I’m afraid that’s not the biggest problem. I am running Vicuna 13B on two GPUs using FSDP. As the basis for generation, I took the LLaMA example from the official Meta repository. The problem is that the forward pass is extremely slow (on the order of a minute to predict one token), while the base LLaMA model from Meta runs orders of magnitude faster in the same FSDP setup. The code is below; for comparison, a generate()-based sketch that keeps the KV cache is appended after it.

from typing import List

import torch

from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaModel, LlamaConfig


class LLaMA:
    def __init__(self, model: LlamaModel, tokenizer: LlamaTokenizer):
        self.model = model
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        bsz = len(prompts)
        config = self.model.module.config if isinstance(self.model, torch.distributed.fsdp.FullyShardedDataParallel) \
            else self.model.config
        # assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        # HF's LlamaTokenizer adds BOS (and no EOS) by default, matching the original bos=True, eos=False
        prompt_tokens = [self.tokenizer.encode(x, add_special_tokens=True) for x in prompts]

        min_prompt_size = min([len(t) for t in prompt_tokens])
        max_prompt_size = max([len(t) for t in prompt_tokens])

        total_len = min(config.max_position_embeddings, max_gen_len + max_prompt_size)

        tokens = torch.full((bsz, total_len), -1).cuda().long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t).long()
        input_text_mask = tokens != -1
        start_pos = min_prompt_size
        prev_pos = 0
        for cur_pos in range(start_pos, total_len):
            # NOTE: prev_pos stays 0 and no past_key_values are passed, so the whole
            # prefix is re-encoded on every step instead of reusing the KV cache
            logits = self.model(tokens[:, prev_pos:cur_pos]).logits[:, -1, :]
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            # stop only once every sequence in the batch has produced EOS
            if (next_token == config.eos_token_id).all():
                break
            # prev_pos = cur_pos

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[: len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[: t.index(self.tokenizer.eos_token_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded


def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
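
For comparison, here is a minimal sketch of the same sampling done through the built-in generate(), which reuses the key/value cache so each step only processes the newly generated token instead of re-encoding the whole prefix. It is shown on an unwrapped LlamaForCausalLM; under FSDP you would still need to route the call through the wrapped module, which this sketch ignores:

from typing import List

import torch

from transformers import LlamaForCausalLM, LlamaTokenizer


def generate_with_cache(
    model: LlamaForCausalLM,
    tokenizer: LlamaTokenizer,
    prompts: List[str],
    max_gen_len: int,
    temperature: float = 0.8,
    top_p: float = 0.95,
) -> List[str]:
    # decoder-only models need left padding so generation continues from the last prompt token
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_gen_len,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            use_cache=True,  # past key/values are reused across steps
        )
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)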