FSDP with Trainer class: AlgorithmError: ValueError('Cannot flatten integer dtype tensors'), exit code: 1

Hello everyone, I’m trying to use FSDP to distribute QLoRA + PEFT fine-tuning across 4 GPUs (single node), using the Trainer class.

model: meta-llama/Meta-Llama-3-8B-Instruct

This is my training script:

from accelerate import Accelerator
from huggingface_hub import login
from peft import AutoPeftModelForCausalLM, LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed
import transformers

def train_fn(
        model_name,
        train_ds,
        test_ds=None,
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        auto_find_batch_size=False,
        gradient_accumulation_steps=1,
        learning_rate=2e-4,
        num_train_epochs=1,
        fsdp="",
        fsdp_config=None,
        chunk_size=2048,
        gradient_checkpointing=False,
        merge_weights=False,
        seed=42,
        token=None
):

    set_seed(seed)

    accelerator = Accelerator()

    if token is not None:
        login(token=token)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Set Tokenizer pad Token
    tokenizer.pad_token = tokenizer.eos_token

    with accelerator.main_process_first():
        # tokenize
        lm_train_dataset = train_ds.map(
            lambda sample: tokenizer(sample["text"]), remove_columns=list(train_ds.features)
        )

        print(f"Total number of train samples: {len(lm_train_dataset)}")

        if test_ds is not None:
            lm_test_dataset = test_ds.map(
                lambda sample: tokenizer(sample["text"]), remove_columns=list(test_ds.features)
            )

            print(f"Total number of test samples: {len(lm_test_dataset)}")
        else:
            lm_test_dataset = None

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        quant_storage_dtype=torch.bfloat16
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        quantization_config=bnb_config,
        attn_implementation="flash_attention_2",
        use_cache=not gradient_checkpointing,
        cache_dir="/tmp/.cache"
    )

    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=gradient_checkpointing)

    # get lora target modules
    modules = find_all_linear_names(model)
    print(f"Found {len(modules)} modules to quantize: {modules}")

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM"
    )

    model = get_peft_model(model, config)
    print_trainable_parameters(model)

    trainer = transformers.Trainer(
        model=model,
        train_dataset=lm_train_dataset,
        eval_dataset=lm_test_dataset,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            gradient_checkpointing=gradient_checkpointing,
            logging_strategy="steps",
            logging_steps=10,
            num_train_epochs=num_train_epochs,
            learning_rate=learning_rate,
            bf16=True,
            fsdp=fsdp,
            fsdp_config=fsdp_config,
            save_strategy="no",
            output_dir="outputs",
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )

    trainer.train()

    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

    if merge_weights:
        output_dir = "/tmp/model"

        # merge adapter weights with base model and save
        # save int 4 model
        trainer.model.save_pretrained(output_dir, safe_serialization=False)
        # clear memory
        del model
        del trainer

        torch.cuda.empty_cache()

        # load PEFT model in fp16
        model = AutoPeftModelForCausalLM.from_pretrained(
            output_dir,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            cache_dir="/tmp/.cache"
        )

        # Merge LoRA and base model and save
        model = model.merge_and_unload()
        model.save_pretrained(
            "/opt/ml/model", safe_serialization=True, max_shard_size="2GB"
        )
    else:
        trainer.model.save_pretrained("/opt/ml/model", safe_serialization=True)

    tokenizer.save_pretrained("/opt/ml/model")
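
For completeness, the two helpers called above (find_all_linear_names and print_trainable_parameters) are the usual QLoRA utilities; roughly:

import bitsandbytes as bnb

def find_all_linear_names(model):
    # Collect the leaf names of every 4-bit linear layer so LoRA can target them.
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, bnb.nn.Linear4bit):
            lora_module_names.add(name.split(".")[-1])
    # lm_head stays in full precision and is excluded from the LoRA targets.
    lora_module_names.discard("lm_head")
    return list(lora_module_names)

def print_trainable_parameters(model):
    # Report how many parameters LoRA actually trains vs. the model total.
    trainable_params, all_params = 0, 0
    for _, param in model.named_parameters():
        all_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_params} "
        f"|| trainable%: {100 * trainable_params / all_params:.2f}"
    )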

My requirements:

transformers==4.40.0
peft==0.10.0
accelerate==0.29.3
bitsandbytes==0.43.1
evaluate==0.4.1
safetensors>=0.4.3
tokenizers>=0.19.1
py7zr

FSDP parameters:

fsdp="full_shard auto_wrap offload",
fsdp_config={
      "backward_prefetch": "backward_pre",
      "forward_prefetch": "false",
      "use_orig_params": "false"
  },
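
For context, this is roughly how train_fn is called inside train.py (dataset loading is omitted, <hf_token> is a placeholder, and the remaining arguments are left at their defaults):

train_fn(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
    train_ds=train_ds,  # the Dataset shown below
    fsdp="full_shard auto_wrap offload",
    fsdp_config={
        "backward_prefetch": "backward_pre",
        "forward_prefetch": "false",
        "use_orig_params": "false",
    },
    token="<hf_token>",
)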

This is an example of train_ds before the map:

Dataset({
    features: ['text'],
    num_rows: 10800
})

train_ds[0]
{'text': '<|start_header_id|>user<|end_header_id|>This is the question<|eot_id|><|start_header_id|>assistant<|end_header_id|>This is the answer<|eot_id|><|end_of_text|>'}

This is lm_train_dataset after the map:

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 10800
})
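
So each tokenized example is just variable-length lists of integer token ids, and DataCollatorForLanguageModeling pads them per batch. A quick sanity check (illustrative):

example = lm_train_dataset[0]
print(example.keys())             # dict_keys(['input_ids', 'attention_mask'])
print(len(example["input_ids"]))  # length varies per sample; padding happens in the collator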

When I run the script with accelerate launch train.py, I get the following exception:

File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 720, in _validate_tensors_to_flatten
    raise ValueError("Cannot flatten integer dtype tensors")

I’m not able to figure out what I should change to make this work. I can only find examples that use SFTTrainer, which I don’t want to adopt. Can you please help me?