Loading a PEFT model saved across multiple nodes using sharded_state_dict?

Hello,

I have trained a model using PEFT, FSDP and Accelerate, with the following script:

from accelerate import Accelerator
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

accelerator = Accelerator()

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct"
)
lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05, bias="none")
model = get_peft_model(model, lora_config)

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
)
tokenizer.pad_token = tokenizer.eos_token

my_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir='./finetuned_llama3.1_8B',
    do_eval=True,
    # ... remaining training arguments ...
)

# train_dataset and val_dataset are prepared elsewhere in the script
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=my_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

trainer.train()

The accelerate config is as follows:

compute_environment: LOCAL_MACHINE
debug: true
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: true
  fsdp_auto_wrap_policy: SIZE_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: true
  fsdp_min_num_params: 10000000
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_process_ip: *************
main_process_port: 29500
main_training_function: main
mixed_precision: fp16
num_machines: 7
num_processes: 28
rdzv_backend: c10d
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

For training, I launched the same script on each node with accelerate launch (roughly as shown below), and the model got saved in shards across the nodes.
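For reference, the launch on each node looked roughly like this (train.py and fsdp_config.yaml are placeholders for my actual file names):

accelerate launch --config_file fsdp_config.yaml train.py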
On each node, under the save directory, I can see the following directory structure:

.
β”œβ”€β”€ optimizer_0
β”‚   β”œβ”€β”€ __0_0.distcp
β”‚   β”œβ”€β”€ __1_0.distcp
β”‚   β”œβ”€β”€ __2_0.distcp
β”‚   └── __3_0.distcp
β”œβ”€β”€ pytorch_model_fsdp_0
β”‚   β”œβ”€β”€ __0_0.distcp
β”‚   β”œβ”€β”€ __1_0.distcp
β”‚   β”œβ”€β”€ __2_0.distcp
β”‚   └── __3_0.distcp
β”œβ”€β”€ rng_state_0.pth
β”œβ”€β”€ rng_state_1.pth
β”œβ”€β”€ rng_state_2.pth
β”œβ”€β”€ rng_state_3.pth
β”œβ”€β”€ scheduler.pt
└── trainer_state.json

On the next node, the shards are saved as:

.
β”œβ”€β”€ optimizer_0
β”‚   β”œβ”€β”€ __4_0.distcp
β”‚   β”œβ”€β”€ __5_0.distcp
β”‚   β”œβ”€β”€ __6_0.distcp
β”‚   └── __7_0.distcp
β”œβ”€β”€ pytorch_model_fsdp_0
β”‚   β”œβ”€β”€ __4_0.distcp
β”‚   β”œβ”€β”€ __5_0.distcp
β”‚   β”œβ”€β”€ __6_0.distcp
β”‚   └── __7_0.distcp
β”œβ”€β”€ rng_state_4.pth
β”œβ”€β”€ rng_state_5.pth
β”œβ”€β”€ rng_state_6.pth
β”œβ”€β”€ rng_state_7.pth
β”œβ”€β”€ scheduler.pt
└── trainer_state.json

and so on…

Now I want to load this PEFT model and run inference on it. If it had been saved on a single node, I would have done something like this:

base_path="meta-llama/Meta-Llama-3.1-8B-Instruct"    # input: base model
adapter_path="./finetuned_llama3.1_8B"     # input: adapters

base_model = AutoModelForCausalLM.from_pretrained(
     base_path
)
model = PeftModel.from_pretrained(base_model, adapter_path)
model = model.merge_and_unload()
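
If the adapters were all in one place, I would then save the merged model to a single directory so that future inference only needs that folder (merged_path below is just a placeholder name I made up):

from transformers import AutoTokenizer

merged_path = "./finetuned_llama3.1_8B_merged"   # placeholder output directory

# Save the merged weights plus the tokenizer for standalone single-node inference
model.save_pretrained(merged_path)
tokenizer = AutoTokenizer.from_pretrained(base_path)
tokenizer.save_pretrained(merged_path)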

But since my adapters are saved in shards across multiple nodes, how do I load them? Should I keep the loading script on each node and launch it with accelerate? And is there any way to save the model in a single place so that I can easily load it for inference on a single node in the future?
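
The only idea I have so far (untested, and I am not sure it is even valid) is to copy the pytorch_model_fsdp_0/__*_0.distcp shards from every node into one directory on a single machine and then try to merge that sharded checkpoint into a single state dict, for example with accelerate's merge_fsdp_weights utility; the paths below are placeholders:

from accelerate.utils import merge_fsdp_weights

# Directory where the __*_0.distcp shards from all 7 nodes have been collected
sharded_dir = "./all_shards/pytorch_model_fsdp_0"
# Where the merged single-file state dict should be written
merged_dir = "./finetuned_llama3.1_8B_merged_adapters"

merge_fsdp_weights(sharded_dir, merged_dir, safe_serialization=True)

But I don't know whether this works when the shards were written by different nodes, or whether the result can then be loaded as a PEFT adapter with PeftModel.from_pretrained. Is this the right direction, or is there a recommended way to consolidate and load such a checkpoint?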

Thanks.