How to load a checkpoint model with SHARDED_STATE_DICT?

I have a checkpoint stored in a folder named pytorch_model_0, which contains multiple .distcp files.
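Before loading, it can help to see what such a folder actually holds. Below is a minimal sketch, assuming PyTorch 2.0 or later, where `torch.distributed.checkpoint.FileSystemReader` can read the checkpoint's `.metadata` file. It lists every stored key with its tensor shape. `describe_checkpoint` is a hypothetical helper name. For a Trainer/Accelerate FSDP checkpoint you would probably see keys beginning with `model.`, but that is an assumption, so check your own output.

```python
import sys

import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.metadata import TensorStorageMetadata


def describe_checkpoint(checkpoint_dir):
    """Hypothetical helper: map each flattened key in a DCP folder to its shape."""
    # Only the small .metadata file is read here, not the .distcp shards.
    metadata = dist_cp.FileSystemReader(checkpoint_dir).read_metadata()
    summary = {}
    for key, item in metadata.state_dict_metadata.items():
        if isinstance(item, TensorStorageMetadata):
            summary[key] = tuple(item.size)  # full (unsharded) tensor shape
        else:
            summary[key] = None  # non-tensor entry stored as bytes
    return summary


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python describe_checkpoint.py /path/to/pytorch_model_0
    for key, shape in sorted(describe_checkpoint(sys.argv[1]).items()):
        print(key, shape)
```

Nested dictionaries are flattened with `.` when saved, so a top-level `"model"` entry shows up as `model.<parameter name>`. That prefix is what a loader has to reproduce.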


I have the same question.
Did you find a solution?

I found a solution for Llama, and you can adapt it for other models.

You can use this:
llama-recipes/docs/inference.md at main · facebookresearch/llama-recipes · GitHub

I edited some files in it to load Mistral instead of Llama.

Yep, I used the same solution as you. I tried to find this code the day you asked, but I couldn't remember where it was. Glad you found it yourself.


Here is my code:

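```python
import fire
import torch.distributed.checkpoint as dist_cp
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


def load_sharded_model_single_gpu(model, model_path):
    """Load a SHARDED_STATE_DICT (.distcp) checkpoint into `model` in one process."""
    # The checkpoint stores the weights under a top-level "model" key, so build a
    # state dict with the same layout; its tensors are the load destinations.
    state_dict = {
        "model": model.state_dict()
    }

    # no_dist=True reads all shards in a single process, without torch.distributed.
    # (Newer PyTorch releases deprecate load_state_dict in favor of dist_cp.load.)
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=dist_cp.FileSystemReader(model_path),
        no_dist=True,
    )

    # Copy the loaded tensors into the model; `result` reports missing/unexpected keys.
    result = model.load_state_dict(state_dict["model"])

    print(f"Sharded state checkpoint loaded from {model_path}")
    print(result)
    return model


def convert_checkpoint(hf_model: str, fsdp_model_path: str, output_path: str):
    '''
    hf_model: transformers model name or path (used for the config and tokenizer).
    fsdp_model_path: path to the FSDP checkpoint, for example `/x/checkpoint-xxx/pytorch_model_x`
    output_path: output path to save the converted checkpoint
    '''
    config = AutoConfig.from_pretrained(hf_model, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(hf_model, trust_remote_code=True)
    # Randomly initialized model with the right architecture; weights come from the FSDP checkpoint.
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    model = load_sharded_model_single_gpu(model, fsdp_model_path)
    # Save in the regular Hugging Face format, split into shards of at most 10GB.
    model.save_pretrained(output_path, max_shard_size="10GB")
    tokenizer.save_pretrained(output_path)


if __name__ == "__main__":
    fire.Fire(convert_checkpoint)
```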
