FSDP training does not save the best checkpoint, and loading from a checkpoint fails

Hi there!

I followed the example for training a T5 model with FSDP on SageMaker: https://github.com/huggingface/notebooks/blob/main/sagemaker/25_pytorch_fsdp_model_parallelism/scripts/run_clm.py

I noticed that checkpointing is disabled with save_strategy="no". Is that intentional?
(https://github.com/huggingface/notebooks/blob/main/sagemaker/25_pytorch_fsdp_model_parallelism/scripts/run_clm.py#L93)
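
For context, here is roughly the checkpointing configuration I switched to; the exact values (output_dir, save_steps, and so on) are illustrative:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="/opt/ml/checkpoints",   # illustrative path
    save_strategy="steps",              # changed from "no"
    save_steps=500,                     # illustrative value
    evaluation_strategy="steps",
    eval_steps=500,
    save_total_limit=2,                 # keep at most 2 checkpoints on disk
    load_best_model_at_end=True,        # expected to protect the best checkpoint
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)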

In my training I changed it to save_strategy="steps" (as in the sketch above) and noticed two issues:

  1. The best checkpoint (the one with the minimum validation loss) is not saved. If I set the save limit to 2, for example, only the last 2 checkpoints are kept.
  2. I was not able to load the trained model from a checkpoint and got the error mentioned elsewhere in the issues: RuntimeError: Trying to resize storage that is not resizable. This does not happen when I load the final model, but it makes training hard, since I need to know when to stop training so that the final saved model is the one with the minimum loss. I tried with different versions:
PyTorch 1.13
Transformers 4.26

and

PyTorch 2.0.0
Transformers 4.28.1

and saw the same issue when loading a model from a checkpoint. Note that I did not resize the model embeddings, which is what I initially thought the error was related to.
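
To make the failure concrete, I load an intermediate checkpoint roughly like this (the checkpoint path here is illustrative):

from transformers import AutoModelForSeq2SeqLM

# Illustrative path to a checkpoint directory written by the Trainer during training.
model = AutoModelForSeq2SeqLM.from_pretrained("/opt/ml/checkpoints/checkpoint-500")
# -> RuntimeError: Trying to resize storage that is not resizable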

As described in the issue "Transformers Trainer + Accelerate FSDP: How do I load my model from a checkpoint?", I saved the model in the following way and would expect to be able to load it back:

from transformers import AutoTokenizer, Trainer


def safe_save_model_for_hf_trainer(trainer: Trainer, tokenizer: AutoTokenizer, output_dir: str):
    """Helper method to save an FSDP-trained model for the HF Trainer."""
    # see: https://github.com/tatsu-lab/stanford_alpaca/issues/65
    from torch.distributed.fsdp import (
        FullyShardedDataParallel as FSDP,
        FullStateDictConfig,
        StateDictType,
    )

    model = trainer.model
    # Gather the full (unsharded) state dict, offloaded to CPU and only on rank 0.
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        cpu_state_dict = model.state_dict()
    if trainer.args.should_save:
        # _save is a private Trainer method, but it writes the model in
        # from_pretrained-compatible format.
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
        tokenizer.save_pretrained(output_dir)
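
I call this at the end of the training script, roughly like so:

trainer.train()
safe_save_model_for_hf_trainer(trainer, tokenizer, trainer.args.output_dir)

Loading the final model saved this way works; it is only the intermediate checkpoints that fail to load.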

Would appreciate any pointers

Thank you!