I’ve trained an OLMo model from a revision whose checkpoint has the state dict sharded across multiple .distcp files and the model weights in rank.pt files. I’m trying to run inference with this checkpoint. How can I convert it into a format that can be loaded with .from_pretrained()?
I’m trying to use the code below, but I get a lot of missing/mismatched key errors:
import fire
import torch.distributed.checkpoint as dist_cp
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM


def load_sharded_model_single_gpu(model, model_path):
    # Wrap the model's state dict under a "model" key to mirror the layout used
    # when the distributed checkpoint was saved, then load it in place on a
    # single process (no_dist=True avoids initializing a process group).
    state_dict = {
        "model": model.state_dict()
    }
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=dist_cp.FileSystemReader(model_path),
        no_dist=True,
    )
    result = model.load_state_dict(state_dict["model"])
    print(f"Sharded state checkpoint loaded from {model_path}")
    print(result)
    return model


def convert_checkpoint(hf_model: str, fsdp_model_path: str, output_path: str):
    '''
    hf_model: transformers path (model id or local directory).
    fsdp_model_path: path to the FSDP checkpoint, for example `/x/checkpoint-xxx/pytorch_model_x`
    output_path: output path to save the converted checkpoint
    '''
    config = AutoConfig.from_pretrained(hf_model, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(hf_model, trust_remote_code=True)
    # Build an uninitialized model with the right architecture, fill it with the
    # sharded weights, then save it in the standard .from_pretrained() layout.
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    model = load_sharded_model_single_gpu(model, fsdp_model_path)
    model.save_pretrained(output_path, max_shard_size="10GB")
    tokenizer.save_pretrained(output_path)


if __name__ == "__main__":
    fire.Fire(convert_checkpoint)
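To see where the mismatch comes from, I’ve also been dumping the key names on both sides with a sketch like the one below (the checkpoint path and model id are placeholders); FileSystemReader.read_metadata() returns the key names recorded in the distributed checkpoint:

import torch.distributed.checkpoint as dist_cp
from transformers import AutoConfig, AutoModelForCausalLM

fsdp_model_path = "/path/to/sharded/checkpoint"  # hypothetical path
hf_model = "allenai/OLMo-7B"                     # hypothetical model id

# Key names recorded in the distributed checkpoint's metadata.
metadata = dist_cp.FileSystemReader(fsdp_model_path).read_metadata()
checkpoint_keys = sorted(metadata.state_dict_metadata.keys())

# Key names the freshly constructed HF model expects. Note: if the checkpoint
# was saved under a {"model": ...} wrapper, its keys carry a leading "model."
# prefix that the bare model state dict does not.
config = AutoConfig.from_pretrained(hf_model, trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
model_keys = sorted(model.state_dict().keys())

print("first checkpoint keys:", checkpoint_keys[:10])
print("first model keys:     ", model_keys[:10])

Comparing the two lists makes it obvious whether the problem is just a prefix difference or the checkpoint uses an entirely different parameter naming scheme than the HF model.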