How do I load a trained checkpoint model?

I’ve trained an OLMo model from a revision, with the state dict sharded across multiple .distcp files and the model weights in rank.pt files. I’m trying to run inference with this checkpoint. How can I convert it into a format that can be loaded with .from_pretrained?

I’m trying to use the code below, but I’m getting many key-mismatch and missing-key errors:

import fire

import torch.distributed.checkpoint as dist_cp

from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AutoModel

def load_sharded_model_single_gpu(model, model_path):
    """Populate ``model`` from the sharded checkpoint stored at ``model_path``.

    The model's current state dict is wrapped under the ``"model"`` key and
    handed to ``dist_cp.load_state_dict`` together with a
    ``FileSystemReader`` for ``model_path``; ``no_dist=True`` is passed so the
    call is made without a process group. The resulting entries are then
    applied with ``model.load_state_dict`` and its report is printed.

    Args:
        model: Module exposing ``state_dict()`` / ``load_state_dict()``.
        model_path: Directory passed to ``dist_cp.FileSystemReader``.

    Returns:
        The same ``model`` object that was passed in.
    """
    # The checkpoint is looked up under the top-level "model" key.
    # NOTE(review): this assumes the checkpoint was saved with that same
    # layout -- confirm against the code that wrote it.
    checkpoint = {"model": model.state_dict()}
    reader = dist_cp.FileSystemReader(model_path)

    dist_cp.load_state_dict(
        state_dict=checkpoint,
        storage_reader=reader,
        no_dist=True,
    )

    load_report = model.load_state_dict(checkpoint["model"])

    print(f"Sharded state checkpoint loaded from {model_path}")
    print(load_report)
    return model

def convert_checkpoint(hf_model: str, fsdp_model_path: str, output_path: str):
    """Convert a sharded FSDP checkpoint into a ``save_pretrained`` directory.

    The config and tokenizer are loaded from ``hf_model``, a model is built
    from that config, its weights are filled in from the checkpoint at
    ``fsdp_model_path``, and both model and tokenizer are written to
    ``output_path``.

    Args:
        hf_model: transformers path.
        fsdp_model_path: path to the fsdp checkpoint, for example
            `/x/checkpoint-xxx/pytorch_model_x`
        output_path: output path to save the converted checkpoint
    """
    hf_config = AutoConfig.from_pretrained(hf_model, trust_remote_code=True)
    hf_tokenizer = AutoTokenizer.from_pretrained(hf_model, trust_remote_code=True)

    # Build the architecture from the config alone, then fill it from the
    # sharded checkpoint.
    restored = load_sharded_model_single_gpu(
        AutoModelForCausalLM.from_config(hf_config, trust_remote_code=True),
        fsdp_model_path,
    )

    # Model first, tokenizer second, both into the same output directory.
    restored.save_pretrained(output_path, max_shard_size="10GB")
    hf_tokenizer.save_pretrained(output_path)

# Script entry point: when this file is run directly (not imported), hand
# convert_checkpoint to python-fire so it is driven from the command line.
if __name__ == "__main__":
    fire.Fire(convert_checkpoint)
1 Like

I think that method is correct, but there seem to be reports of tensor mismatch issues when training with FSDP.