Transformers Trainer + Accelerate FSDP: How do I load my model from a checkpoint?

Hi all,

I’ve fine-tuned a Llama 2 model using the transformers Trainer class, plus accelerate and FSDP, with a sharded state dict. Now my checkpoint directories all have the model’s state dict sharded across multiple .distcp files. How do I open them, or convert them to a format I can load with .from_pretrained()? I haven’t found documentation on this anywhere. Any help would be greatly appreciated!

Thank you so much!


So I ran into the same issue, and after many, many Google searches I found that llama_recipes has a pull request that fixes the issue and provides a doc/script for it. I haven’t tested it yet, but it seems helpful.
Specifically, the command should be:
python -m llama_recipes.inference.checkpoint_converter_fsdp_hf --fsdp_checkpoint_path PATH/to/FSDP/Checkpoints --consolidated_model_path PATH/to/save/checkpoints --HF_model_path_or_name PATH/or/HF/model_name


Thank you so much @mqo! That does fix it! 🙂 To follow up with some more information for anyone else stumbling across this:

Doing this yourself

You can also do this in a Jupyter notebook without the llama_recipes function, by replicating what they do. That gives you a bit more control, and you can check that model outputs are what you expect before you save the consolidated model (there’s a sketch of this after the steps). The steps are:

  1. Load a model from a working checkpoint, e.g.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    modelpath,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
tokenizer = AutoTokenizer.from_pretrained(modelpath)

Important things to note here: First, this works entirely on CPU! (But you can do it on GPU too, of course.) Second, make sure you specify the torch dtype, otherwise you can end up with an FP32 checkpoint when your original model was BF16! Third, the model must be exactly the same as the one you want to recover. If you resized the embedding (e.g. because you added special tokens to the tokenizer), then you must load a model here that has the resized embedding! (You should be able to load the original model, resize, and immediately save to disk without training to get such a checkpoint, then use that in the code above; there’s a sketch of this after the steps.)
  2. Load the distcp checkpoint:

import torch.distributed._shard.checkpoint as dist_cp

# Use the current model's weights as a template for the sharded load
state_dict = {"model": model.state_dict()}
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(distcp_checkpoint_path),
    no_dist=True,
)
  3. Put the weights into the model you loaded in the first step:

model.load_state_dict(state_dict["model"])

  4. Save the model using model.save_pretrained() as you usually would.
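
As mentioned above, one advantage of doing this manually is that you can sanity-check the restored weights before consolidating. Here is a minimal sketch, continuing from the steps above; the prompt and output path are placeholders, not from the original post:

# Quick sanity check: generate from the restored model before saving
inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

# If the output looks right, consolidate to a normal HF checkpoint
model.save_pretrained("consolidated_model")  # placeholder path
tokenizer.save_pretrained("consolidated_model")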
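
And if you did resize the embeddings (see the note under step 1), here is a sketch of how you might produce a shape-compatible base checkpoint to use as modelpath in step 1. The paths and special tokens below are hypothetical; substitute whatever you actually used before training:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = "meta-llama/Llama-2-7b-hf"  # hypothetical base model
resized_path = "llama2-resized"  # hypothetical output directory
special_tokens = ["<|custom|>"]  # hypothetical tokens added before training

tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

model = AutoModelForCausalLM.from_pretrained(
    base_model, torch_dtype=torch.bfloat16, device_map="cpu"
)
# Resize the embedding matrix to match what you trained with
model.resize_token_embeddings(len(tokenizer))

# Save immediately, without training, to get a shape-compatible checkpoint
model.save_pretrained(resized_path)
tokenizer.save_pretrained(resized_path)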

Avoiding this altogether:

The better way is to save the model properly in the first place. That’s described, for instance, here. Use the following code after training:

if trainer.is_fsdp_enabled:
    # Switch from the sharded format to a full state dict before saving
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model(output_dir)
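
With FULL_STATE_DICT, FSDP gathers the sharded weights into a single full state dict before writing, so trainer.save_model() produces a regular Hugging Face checkpoint that .from_pretrained() can load directly, with no .distcp conversion needed afterwards.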