Subject: Issues with Custom Model Saving Behavior Using Trainer Class in LVLM Training

I have a question about the behavior of the Trainer class when training an LVLM. When I use Seq2SeqTrainingArguments to specify the model saving criteria (e.g., save every N steps or every epoch), checkpoints are saved without any issues. However, when I instead implemented the callback below to save the model at specific step counts, no checkpoints are written at all. I would appreciate any advice on how to resolve this issue.
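
For reference, the built-in saving that works for me is configured roughly like this (a minimal sketch; the concrete values are illustrative, not my exact settings):

from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

# Built-in checkpointing: the Trainer saves automatically according to save_strategy.
training_args = Seq2SeqTrainingArguments(
    output_dir="./outputs",        # illustrative path
    save_strategy="steps",         # or "epoch" to save once per epoch
    save_steps=500,                # save every 500 optimizer steps
    save_total_limit=3,            # keep only the 3 most recent checkpoints
)

trainer = Seq2SeqTrainer(
    model=model,                   # `model` and the datasets are defined elsewhere
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()  # checkpoints appear as ./outputs/checkpoint-500, checkpoint-1000, ...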

Implemented Code:

from transformers import TrainerCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments

class CustomModelSaveCallback(TrainerCallback):
    def __init__(self):
        self.trainer = None  # The Trainer instance is attached later via set_trainer()

    def set_trainer(self, trainer_instance):
        self.trainer = trainer_instance

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Save roughly save_total_limit times over the run, i.e. every (max_steps // save_total_limit) steps
        if state.global_step % (state.max_steps // args.save_total_limit) == 0:
            output_dir = f"{args.output_dir}/checkpoint-{int(state.global_step)}"
            self.trainer.save_model(output_dir)
            print(f"Saving model to {output_dir} at step {state.global_step}")

The Trainer I use overrides save_model as follows:

def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
    if self.fsdp is not None:
        if output_dir is None:
            output_dir = self.args.output_dir
        from torch.distributed.fsdp import (
            FullyShardedDataParallel as FSDP,
            FullStateDictConfig,
            StateDictType,
        )
        # Gather the full (unsharded) state dict; with offload_to_cpu=False and
        # rank0_only=False, every rank keeps the gathered weights on GPU.
        save_policy = FullStateDictConfig(offload_to_cpu=False, rank0_only=False)
        with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, save_policy):
            cpu_state_dict = self.model.state_dict()
        if self.args.should_save:
            self._save(output_dir, state_dict=cpu_state_dict)  # noqa
        # Push to the Hub when `save_model` is called by the user.
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")
    else:
        super().save_model(output_dir, _internal_call)

Environment:

  • torch 2.1.2+cu121
  • transformers 4.29.0
  • Machine: DGX H100
  • Training method: FSDP

Regarding save_model: since the machine has ample GPU memory, I do not enable rank0_only or offload_to_cpu (both are set to False in FullStateDictConfig).
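
For comparison, the more memory-conservative policy that I am deliberately not using would look like the sketch below (these are standard FSDP options, not code from my training script):

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

# Gather the full state dict on rank 0 only and offload it to CPU before saving;
# commonly used when GPU memory is tight.
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    state_dict = model.state_dict()  # only rank 0 receives the full state dict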