Transformers Trainer + Accelerate FSDP: How do I load my model from a checkpoint?

Thank you so much @mqo! That does fix it! To follow up with some more information for anyone else stumbling across this:

Doing this yourself

You can also do this in a Jupyter notebook without the llama_recipes function by replicating what it does. That gives you a bit more control: you can check that the model outputs are what you expect before you save the consolidated model. The steps are:

1. Load a model from a working checkpoint, e.g.:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    modelpath,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
tokenizer = AutoTokenizer.from_pretrained(modelpath)

Important things to note here: First, this works entirely on CPU (though you can of course also run it on GPU). Second, make sure you specify torch_dtype; otherwise you can end up with an FP32 checkpoint even though your original model was BF16. Third, the model must have exactly the same architecture and parameter shapes as the one you want to recover. If you resized the embeddings (e.g. because you added special tokens to the tokenizer), the model you load here must have the resized embeddings too. You should be able to get such a checkpoint by loading the original model, resizing its embeddings, and immediately saving it to disk without training, then using that path as modelpath in the code above.
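2. Load the distcp checkpoint into a state dict built from the model above:

import torch.distributed.checkpoint as dist_cp

# The sharded checkpoint keeps the weights under a top-level "model" key,
# so the state dict we load into needs the same nesting.
state_dict = {
    "model": model.state_dict()
}
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(distcp_checkpoint_path),
    no_dist=True,  # read all shards in this single process, no process group needed
)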
3. Put the weights into the model you loaded in the first step:
model.load_state_dict(state_dict["model"])
4. Save the model with model.save_pretrained(output_dir) as you usually would (and tokenizer.save_pretrained(output_dir) to keep the tokenizer alongside it).
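The steps above can be wrapped into a small reusable script. This is a minimal sketch under a few assumptions: the checkpoint was written by Accelerate/Trainer with the weights under a top-level "model" key (as in step 2), the model is a causal LM trained in BF16, and the helper names load_distcp_into_model and consolidate_fsdp_checkpoint, as well as the optional tokenizer_path argument, are hypothetical. It keeps the post's torch.distributed.checkpoint.load_state_dict call with no_dist=True; recent PyTorch releases deprecate that function in favor of torch.distributed.checkpoint.load, so expect a warning there.

```python
import torch
import torch.distributed.checkpoint as dist_cp


def load_distcp_into_model(model, distcp_checkpoint_path):
    """Steps 2 and 3: copy weights from a sharded (distcp) checkpoint into `model`.

    Hypothetical helper. Assumes the checkpoint stores the weights under a
    top-level "model" key and that `model` has exactly the same parameter
    names and shapes as the model that was trained.
    """
    state_dict = {"model": model.state_dict()}
    # no_dist=True reads every shard in this single process, without a process group.
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=dist_cp.FileSystemReader(distcp_checkpoint_path),
        no_dist=True,
    )
    model.load_state_dict(state_dict["model"])
    return model


def consolidate_fsdp_checkpoint(
    modelpath, distcp_checkpoint_path, output_dir, tokenizer_path=None
):
    """Steps 1 to 4 for a Hugging Face causal LM, entirely on CPU (hypothetical helper)."""
    # Imported here so that load_distcp_into_model also works with plain PyTorch modules.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Step 1: reference model in the training dtype (BF16 assumed here).
    model = AutoModelForCausalLM.from_pretrained(
        modelpath, torch_dtype=torch.bfloat16, device_map="cpu"
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path or modelpath)
    if tokenizer_path is not None:
        # Assumption: training resized the embeddings to exactly len(tokenizer).
        model.resize_token_embeddings(len(tokenizer))

    # Steps 2 and 3: pull the fine-tuned weights into the reference model.
    load_distcp_into_model(model, distcp_checkpoint_path)

    # Step 4: save model and tokenizer in the usual Hugging Face format.
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    return model
```

For example, consolidate_fsdp_checkpoint(modelpath, distcp_checkpoint_path, output_dir) reproduces steps 1 to 4, and you can still run a quick generation on the returned model before relying on the saved files. The optional tokenizer_path resizes the embeddings in memory instead of saving a resized checkpoint first; this only works if training resized them in exactly the same way (for instance, with the same pad_to_multiple_of, if any).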

Avoiding this altogether:

The better way is to save the model properly in the first place. That's described, for instance, here. Use the following code after training so that FSDP gathers the full (unsharded) weights before saving:

if trainer.is_fsdp_enabled:
    # Gather the full state dict instead of writing per-rank distcp shards.
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model(output_dir)
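Whichever route you take, it is worth confirming that the saved checkpoint really contains the fine-tuned weights rather than the untouched base model. A minimal sketch, assuming both models load with from_pretrained in the same dtype: count_changed_tensors is a hypothetical helper that compares two state dicts and reports how many shared tensors differ.

```python
import torch


def count_changed_tensors(base_state, tuned_state, atol=0.0):
    """Compare two state dicts and return (changed, compared).

    Hypothetical helper. Only keys present in both dicts are compared; a
    tensor counts as changed if its shape differs (e.g. resized embeddings)
    or if any element differs by more than `atol`.
    """
    shared_keys = base_state.keys() & tuned_state.keys()
    changed = 0
    for key in sorted(shared_keys):
        base, tuned = base_state[key], tuned_state[key]
        # Compare in FP32 so BF16 tensors are handled the same way.
        if base.shape != tuned.shape or not torch.allclose(
            base.float(), tuned.float(), rtol=0.0, atol=atol
        ):
            changed += 1
    return changed, len(shared_keys)
```

For example, pass the .state_dict() of the base model and of the model reloaded from output_dir. Zero changed tensors suggests the fine-tuned weights never made it into the saved checkpoint; parameters that were frozen during training are, of course, expected to stay unchanged.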