I have a question regarding the behavior of the Trainer class in LVLM training when using Seq2SeqTrainingArguments to specify the model saving criteria (e.g., steps, epochs). When setting up the training with these parameters, the model saves without any issues. However, I encountered a problem where the model does not start saving when I implemented the following code to save the model at specific step counts or epochs. I would appreciate any advice on how to resolve this issue.
Implemented Code:
from transformers import TrainerCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
class CustomModelSaveCallback(TrainerCallback):
def __init__(self):
self.trainer = None # No Trainer instance is set at initialization
def set_trainer(self, trainer_instance):
self.trainer = trainer_instance
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if state.global_step % (state.max_steps//args.save_total_limit) == 0:
output_dir = f"{args.output_dir}/checkpoint-{int(state.global_step)}"
self.trainer.save_model(output_dir)
print(f"Saving model to {args.output_dir} at step {state.global_step}")
Trainer has a save_model
function below:
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
if self.fsdp is not None:
if output_dir is None:
output_dir = self.args.output_dir
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
FullStateDictConfig,
StateDictType,
)
save_policy = FullStateDictConfig(offload_to_cpu=False, rank0_only=False)
# print(StateDictType.FULL_STATE_DICT)
# print(save_policy)
with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, save_policy):
cpu_state_dict = self.model.state_dict()
if self.args.should_save:
self._save(output_dir, state_dict=cpu_state_dict) # noqa
# Push to the Hub when `save_model` is called by the user.
if self.args.push_to_hub and not _internal_call:
self.push_to_hub(commit_message="Model save")
else:
super().save_model(output_dir, _internal_call)
Environment:
- torch 2.1.2+cu121
- transformers 4.29.0
- Machine: DGX H100
- Training way: FSDP
Regarding save_model
, given the ample machine performance, I am not implementing rank0_only or offloading to CPU.