KeyErrors when trying to load an accelerate FSDP model checkpoint

Hi,

I have been training a pythia-2.8b model using the DPOTrainer from Hugging Face TRL. I am training a LoRA adapter on top of the reference SFT model. To speed things up and address memory bottlenecks, I used FSDP and launched my code with:

accelerate launch --config_file=fsdp.yaml --num_processes=4 run_dpo.py

After training, my checkpoints have the following structure:

optimizer_0/ pytorch_model_fsdp_0/ rng_state_0.pth rng_state_1.pth scheduler.pt trainer_state.json

When I try to load this checkpoint with the code below, I keep getting KeyErrors: some of the model's keys are missing from the saved checkpoint. The same behavior occurs when training the facebook/opt-350m model, so it seems like I am missing something here. The code I use to load the FSDP checkpoint is provided below:

import torch.distributed.checkpoint as dist_cp
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AutoModel
from peft import LoraConfig, get_peft_model

def load_sharded_model_single_gpu(model, model_path):
    # Read the sharded FSDP checkpoint into a state dict shaped like the model's,
    # then load it back into the model on a single process (no_dist=True).
    state_dict = {
        "model": model.state_dict()
    }

    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=dist_cp.FileSystemReader(model_path),
        no_dist=True,
    )

    result = model.load_state_dict(state_dict["model"])

    print(f"Sharded state checkpoint loaded from {model_path}")
    print(result)
    return model

def convert_checkpoint(hf_model: str, fsdp_model_path: str, output_path: str):
    '''
    hf_model: Hugging Face model name or path (used to build the config and tokenizer).
    fsdp_model_path: path to the FSDP checkpoint, for example `/x/checkpoint-xxx/pytorch_model_x`
    output_path: path where the converted checkpoint is saved
    '''
    config = AutoConfig.from_pretrained(hf_model, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(hf_model, trust_remote_code=True)
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

    # peft_config = LoraConfig(
    #     r=16,
    #     lora_alpha=32,
    #     lora_dropout=0.05,
    #     bias="none",
    #     task_type="CAUSAL_LM",
    # )
    # model = get_peft_model(model, peft_config, adapter_name='__train__')

    model = load_sharded_model_single_gpu(model, fsdp_model_path)
    model.save_pretrained(output_path, max_shard_size="10GB")
    tokenizer.save_pretrained(output_path)
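
# Example of how the conversion is invoked (paths and model name here are placeholders,
# not my exact values; the FSDP path points at the pytorch_model_fsdp_0/ directory
# from the checkpoint structure shown above):
convert_checkpoint(
    hf_model="EleutherAI/pythia-2.8b",
    fsdp_model_path="/path/to/checkpoint-xxx/pytorch_model_fsdp_0",
    output_path="/path/to/converted-checkpoint",
)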

Could someone help me understand why I keep getting these key errors when trying to load PEFT + FSDP checkpoints? There are no errors when I load the entire model (i.e. no PEFT) from FSDP checkpoints. Also, could you explain how to effectively save and load models when using FSDP? I have tried accelerator.load_state(<ckpt_path>), as suggested in the documentation, but it complains that the pytorch_model_fsdp.bin file is missing from the checkpoint, and the DPOTrainer never saved such a file.
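
For reference, this is roughly what I tried with load_state (a sketch, not my exact script; <ckpt_path> stands for one of the checkpoint directories above):

from accelerate import Accelerator
from transformers import AutoModelForCausalLM

accelerator = Accelerator()
model = AutoModelForCausalLM.from_pretrained("rmrafailov/TLDR-Pythia2.8B-SFT")
model = accelerator.prepare(model)

# This call complains that pytorch_model_fsdp.bin is missing from the checkpoint:
accelerator.load_state("<ckpt_path>")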

I am also providing my run_dpo.py code for the sake of completeness:

from collections import defaultdict
from peft import PeftModel, LoraConfig, get_peft_model
from trl import DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from accelerate import Accelerator
from data import get_tldr
from datasets import Dataset
import torch
import numpy as np
import random
import wandb
import datetime
import os, sys, argparse

def to_hf_format(data):
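    # Assumed input structure (hypothetical example):
    #   data = {
    #       "<prompt text>": {
    #           "responses": ["response 0", "response 1", ...],
    #           "pairs": [(i, j), ...],  # responses[i] preferred over responses[j]
    #       },
    #       ...
    #   }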
    new_data = defaultdict(list)
    for k in data:
        for i, j in data[k]['pairs']:
            new_data['prompt'].append(k)
            new_data['chosen'].append(data[k]['responses'][i])
            new_data['rejected'].append(data[k]['responses'][j])
    return Dataset.from_dict(new_data)

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="My Script Description")
    parser.add_argument("--seed", type=int, default=42, help="Seed for random number generation")
    args = parser.parse_args()

    # Set seed
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    datapath = os.path.join(".", "data/tldr/comparisons")

    train_data = get_tldr(datapath, 'train')
    test_data = get_tldr(datapath, 'test')

    train_data, test_data = map(to_hf_format, [train_data, test_data])

    ######################### ARGS ###########################
    model_name = "EleutherAI/pythia-2.8b"
    sft_ckpt_name = "rmrafailov/TLDR-Pythia2.8B-SFT"
    cache_dir = "/data/yaswanth_chittepu/Data/dpo_cache/"
    batch_size = 16
    gradient_accumulation_steps = 16
    gradient_checkpoint = True
    ##########################################################

    torch.distributed.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=36000))

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        sft_ckpt_name,
        cache_dir=cache_dir,
        device_map={"": Accelerator().local_process_index},
    )

    if gradient_checkpoint:
        # Needed when using gradient checkpointing;
        # otherwise training complains that there are no gradients to compute.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config, adapter_name='__train__')

    out_dir = os.path.join("/data/yaswanth_chittepu/Data/robust_dpo/dpo/", f"seed-{args.seed}")
    os.makedirs(out_dir, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=out_dir,
        run_name=f"pythia-2.8b-tldr-dpo-seed-{args.seed}",
        report_to="wandb",  # this tells the Trainer to log the metrics to W&B
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=8,
        bf16=True,
        learning_rate=1e-4,
        lr_scheduler_type="cosine",
        warmup_ratio = 0.1,
        gradient_accumulation_steps=gradient_accumulation_steps,
        gradient_checkpointing=gradient_checkpoint,
        gradient_checkpointing_kwargs={"use_reentrant":False},
        evaluation_strategy="steps",
        eval_steps=10,
        num_train_epochs=4,
        # logging strategies 
        logging_strategy="steps",
        logging_steps=1,
        save_strategy="steps",
        save_steps = 20,
        max_grad_norm=1.,
        remove_unused_columns=False,
        seed=args.seed,
    )

    # Initialize the trainer without a separate ref_model (ref_model=None).
    dpo_trainer = DPOTrainer(
        model,
        ref_model=None,
        beta=0.1,
        train_dataset=train_data,
        eval_dataset=test_data,
        tokenizer=tokenizer,
        max_length=512,
        max_prompt_length=256,
        args=training_args,
    )

    dpo_trainer.train()
    
    final_dir = os.path.join(training_args.output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)
    dpo_trainer.save_model(final_dir)