Getting the error: AssertionError: Non-root FSDP instance's `_is_root` should not have been set yet or should have been set to `False` while fine-tuning a GPT-2 model

Hi, I am getting the error above while fine-tuning GPT-2 with Accelerate (using the FSDP strategy) on my custom datasets.

My code:

import pandas as pd
import os
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_seq_items', None)
pd.set_option('display.max_colwidth', 500)
pd.set_option('expand_frame_repr', True)
import time
from tqdm import tqdm
tqdm.pandas()
import wandb

import transformers
from datetime import datetime
from peft import prepare_model_for_kbit_training

from peft import LoraConfig, get_peft_model
from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from accelerate import PartialState
from peft import PeftModel
from datasets import load_dataset



def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )
    
def formatting_func(x):
    text=create_prompt(str(x['column1']))
    return text
    
def create_prompt(input1:str):                       
    user_prompt = f"""
    
    Complete the given input sentence.
    
    """
    prompt = f"""
        [INST] {user_prompt} [/INST]
        Output: {input1}
        """
    return prompt
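# Illustrative only: for a hypothetical row value such as "The cat sat on the mat.",
# create_prompt wraps the text in an [INST] ... [/INST] instruction template, e.g.
#   print(create_prompt("The cat sat on the mat."))
# produces (extra whitespace/newlines omitted here for readability):
#   [INST] Complete the given input sentence. [/INST] Output: The cat sat on the mat.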

from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
gpt_model = GPT2LMHeadModel.from_pretrained('gpt2', device_map={"": PartialState().process_index})


def generate_and_tokenize_prompt2(prompt):
    result = tokenizer(
        formatting_func(prompt),
        truncation=True,
        max_length=512,
        padding="max_length",
        return_tensors='pt'
    )
    result["labels"] = result["input_ids"].clone().detach()
    return result

train_dataset = pd.read_csv('./data/sample_train.csv')
eval_dataset = pd.read_csv('./data/sample_test.csv')
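# Note on the data: formatting_func above reads x['column1'], so both CSVs are
# expected to contain a text column named 'column1'. A minimal hypothetical
# stand-in for reproducing the issue without the real data could be:
#   train_dataset = pd.DataFrame({'column1': ['The cat sat on the mat.'] * 100})
#   eval_dataset = pd.DataFrame({'column1': ['The dog chased the ball.'] * 20})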

tokenized_train_dataset = train_dataset.progress_apply(lambda x: generate_and_tokenize_prompt2(x), axis=1).values


tokenized_val_dataset = eval_dataset.progress_apply(lambda x: generate_and_tokenize_prompt2(x), axis=1).values

print_trainable_parameters(gpt_model)


project = "output"
base_model_name = "gpt-2-training"
output_dir = "./output/" + base_model_name
run_name = '1'

# training using accelerate
accelerator = Accelerator()
trainer = accelerator.prepare(transformers.Trainer(
    model=gpt_model,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    args=transformers.TrainingArguments(
        output_dir=output_dir,
        warmup_steps=1,
        per_device_train_batch_size=5,
        gradient_accumulation_steps=1,
        gradient_checkpointing=False,
        gradient_checkpointing_kwargs={'use_reentrant':False},
        max_steps=20,
        learning_rate=2.5e-5, # Want a small lr for finetuning
        #optim="paged_adamw_8bit",
        logging_steps=2,              # Log the training loss every 2 steps
        logging_dir="./logs",        # Directory for storing logs
        save_strategy="steps",       # Save checkpoints on a step interval
        save_steps=5,                # Save a checkpoint every 5 steps
        evaluation_strategy="steps", # Evaluate on a step interval
        eval_steps=20,               # Evaluate every 20 steps
        do_eval=True,                # Run evaluation during training
        per_device_eval_batch_size=5,
        # fsdp='full_shard',
        #dataloader_num_workers =1,
        #predict_with_generate=True,
        report_to="wandb",           # Comment this out if you don't want to use Weights & Biases
        run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"          # Name of the W&B run (optional)
    ),
    #data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
))

# model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

accelerate config:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: GPT2Block
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Execution command:

accelerate launch --config-file <config_file> <python_file>.py

Library versions:

transformers==4.41.0
accelerate==0.30.1

Could anyone please help me solve this?
Thanks in advance.