I’ve fine-tuned a Llama2 model using the transformers Trainer class, plus accelerate and FSDP, with a sharded state dict. Now my checkpoint directories all have the model’s state dict sharded across multiple .distcp files; how do I open them, or convert them to a format I can open with .from_pretrained()? I’ve not found documentation on this anywhere. Any help would be greatly appreciated!
I ran into the same issue, and after many, many Google searches I found that llama_recipes has a pull request that fixed it and provides a doc/script for it. I haven't tested it yet, but it seems helpful.
Specifically, the command is:

```
python -m llama_recipes.inference.checkpoint_converter_fsdp_hf \
    --fsdp_checkpoint_path PATH/to/FSDP/Checkpoints \
    --consolidated_model_path PATH/to/save/checkpoints \
    --HF_model_path_or_name PATH/or/HF/model_name
```
Thank you so much @mqo! That does fix it! Also, to follow up with some more information for anyone else stumbling across this:
Doing this yourself
You can also do this in a Jupyter notebook without the llama_recipes function by replicating what it does. That gives you a bit more control, and you can check that model outputs are what you expect before you save the consolidated model. The steps are:
1. Load a model from a working checkpoint, e.g.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    modelpath,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
tokenizer = AutoTokenizer.from_pretrained(modelpath)
```
Important things to note here:

- This works entirely on CPU! (But you can do it on GPU too, of course.)
- Make sure you specify the torch dtype - otherwise you can end up with an FP32 checkpoint when your original model was BF16!
- The model must be exactly the same as the one you want to recover. If you resized the embeddings (e.g. because you added special tokens to the tokenizer), then you must load a model here that has the resized embeddings! (You should be able to load the original model, resize, and immediately save it to disk without training to get such a checkpoint, then use that in the code above.)
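The resize-then-save step from the last note can be sketched like this. This is a minimal sketch, not the thread author's exact code: the base model name, the added pad token, and the output directory are all assumptions you should replace with whatever you actually used during training.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = "meta-llama/Llama-2-7b-hf"  # assumption: your original base model
resized_path = "llama2-resized"          # hypothetical output directory

# Recreate the tokenizer exactly as it was set up for training,
# including any special tokens you added before fine-tuning.
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.add_special_tokens({"pad_token": "<pad>"})  # assumption: a pad token was added

# Load the base model in the dtype you trained in, resize the embeddings
# to the new vocab size, and save immediately - no training involved.
model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.bfloat16)
model.resize_token_embeddings(len(tokenizer))
model.save_pretrained(resized_path)
tokenizer.save_pretrained(resized_path)
```

You can then point `modelpath` in the loading snippet above at `resized_path`, so the model you load matches the shape of your FSDP checkpoint.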
2. Load the distcp checkpoint: