IndexError on devices[0] when initializing a Trainer

Hi,
I’m relatively new to Hugging Face, and I’m facing an error I’m not able to debug when trying to fine-tune a Vigogne model on my own data.
First of all some context:
I’m running everything in a Jupyter Notebook on AWS SageMaker (Instance ml.m5.2xlarge / ml.m5.4xlarge) with Python 3.

I’m then using a code snippet shared by the person who trained the French Alpaca model I want to use:

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaTokenizer
# Load the base LLaMA checkpoint, apply the Vigogne LoRA adapter on top of it,
# and pick device placement / precision from the hardware that is available.
# The model shards of decapoda-research/llama-7b-hf are smaller, which is safer
# to use with Colab's RAM constraint.
base_model_name_or_path = "decapoda-research/llama-7b-hf"
lora_model_name_or_path = "bofenghuang/vigogne-7b-chat"

load_8bit = True

# LlamaTokenizer is used when the checkpoint name contains "llama";
# any other checkpoint falls back to AutoTokenizer.
tokenizer_class = LlamaTokenizer if "llama" in base_model_name_or_path else AutoTokenizer
tokenizer = tokenizer_class.from_pretrained(base_model_name_or_path, padding_side="right", use_fast=False)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Older torch builds have no `torch.backends.mps` attribute. Only that specific
# failure is ignored, so unrelated errors (and Ctrl-C) are no longer swallowed.
try:
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:
    pass

if device == "cuda":
    # GPU path: 8-bit weights (when load_8bit is set) with automatic placement.
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name_or_path,
        load_in_8bit=load_8bit,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(
        model,
        lora_model_name_or_path,
        torch_dtype=torch.float16,
    )
elif device == "mps":
    # Apple-silicon path: whole model pinned to the MPS device in half precision.
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name_or_path,
        device_map={"": device},
        torch_dtype=torch.float16,
    )
    model = PeftModel.from_pretrained(
        model,
        lora_model_name_or_path,
        device_map={"": device},
        torch_dtype=torch.float16,
    )
else:
    # CPU path: whole model pinned to the CPU.
    # NOTE(review): passing device_map here presumably leaves an hf_device_map
    # containing only "cpu" on the model; code that indexes the non-CPU entries
    # of that map may then find none -- confirm against the installed
    # transformers version.
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name_or_path, device_map={"": device}, low_cpu_mem_usage=True
    )
    model = PeftModel.from_pretrained(
        model,
        lora_model_name_or_path,
        device_map={"": device},
    )

if not load_8bit and device != "cpu":
    model.half()  # seems to fix bugs for some users.

model.eval()

# Compare the numeric major version rather than the raw version string, so the
# check cannot degrade into a lexicographic string comparison.
if int(torch.__version__.split(".")[0]) >= 2:
    model = torch.compile(model)

Up to that point, everything is fine: I can chat with Vigogne and it seems to work well! (A bit slow, though.)

Then I load my own dataset and tokenize the data, which gives me 2 columns: one with the input IDs of the questions, called “input_ids”, and one with the input IDs of the answers, called “label”. I then try to fine-tune my model with it:

import time
from transformers import TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, TaskType

# LoRA adapter settings for causal language modelling.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=32,  # rank of the low-rank update matrices
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Wrap the already-loaded model with the new LoRA adapter.
peft_model = get_peft_model(model, lora_config)

# One output directory per run, suffixed with the current Unix time in seconds.
output_dir = f'./peft-dialogue-summary-training-{int(time.time())}'

peft_training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=1e-3,  # higher learning rate than full fine-tuning
    auto_find_batch_size=True,
    logging_steps=1,
    num_train_epochs=1,
    # Per the TrainingArguments documentation, a positive max_steps takes
    # precedence over num_train_epochs.
    max_steps=1,
)

# Build the Trainer around the PEFT-wrapped model and the tokenized dataset.
peft_trainer = Trainer(
    args=peft_training_args,
    model=peft_model,
    train_dataset=tokenized_data,
)

And then I hit my issue:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-16-c4b1c566cc63> in <module>
     13     model=peft_model,
     14     args=peft_training_args,
---> 15     train_dataset=tokenized_data,
     16 )

/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in __init__(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
    394                 self.is_model_parallel = True
    395             else:
--> 396                 self.is_model_parallel = self.args.device != torch.device(devices[0])
    397 
    398             # warn users

IndexError: list index out of range

Any idea on how to fix this?

Many thanks !