8-bit precision error

Hi

I am fine-tuning the Mistral 7B model for multilabel classification (loading it in 4-bit, though the error message talks about 8-bit) and am getting the following error:

```
ValueError: You can't train a model that has been loaded in 8-bit precision on a different device than the one you're training on. Make sure you loaded the model on the correct device using for example `device_map={'': torch.cuda.current_device()}` or `device_map={'': torch.xpu.current_device()}`
I0000 00:00:1711804411.121734 181893 cpu_client.cc:373] TfrtCpuClient destroyed.
```
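For context, the fix the message suggests is pinning the whole model to the device you train on, e.g. (minimal sketch, reusing `model_name`, `quantization_config`, and `le` from my code below, and assuming a single CUDA GPU):

```python
import torch
from transformers import AutoModelForSequenceClassification

# '' maps every module, i.e. the whole model, to the current CUDA device
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    num_labels=len(le.classes_),
    device_map={"": torch.cuda.current_device()},
)
```

As far as I can tell, my code below already does the equivalent via `PartialState().process_index`, yet the error persists.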

Here is my code:

Quantization config:

```python
import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # enable 4-bit quantization
    bnb_4bit_quant_type='nf4',              # information-theoretically optimal dtype for normally distributed weights
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants // insert xzibit meme
    bnb_4bit_compute_dtype=torch.bfloat16,  # optimized fp format for ML
)
```

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,                 # dimension of the low-rank matrices
    lora_alpha=8,         # scaling factor for LoRA activations vs. pre-trained weight activations
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    lora_dropout=0.05,    # dropout probability of the LoRA layers
    bias='none',          # whether to train bias weights; set to 'none' for attention layers
    task_type='SEQ_CLS',
)
```

```python
from accelerate import PartialState
from peft import get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    num_labels=len(le.classes_),
    device_map={"": PartialState().process_index},
)

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
model.config.pad_token_id = tokenizer.pad_token_id
```
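For what it's worth, PEFT's built-in helper can confirm which parameters are trainable at this point:

```python
# prints trainable vs. total parameter counts (LoRA matrices plus the classification head)
model.print_trainable_parameters()
```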

Define training args:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='multilabel_classification',
    learning_rate=1e-4,
    per_device_train_batch_size=8,   # tested with 16 GB GPU RAM
    per_device_eval_batch_size=8,
    num_train_epochs=10,
    weight_decay=0.01,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
)
```

```python
# down-weight frequent labels: weight_i = 1 - count_i / total_count
label_weights = 1 - labels_encoded / labels_encoded.sum()
```
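A toy example of what that weighting does, with made-up per-label counts:

```python
import torch

# hypothetical positive counts for three labels
labels_encoded = torch.tensor([10., 30., 60.])
print(1 - labels_encoded / labels_encoded.sum())
# tensor([0.9000, 0.7000, 0.4000]) -> rarer labels get larger weights
```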

Train:

```python
import functools

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds['train'],
    eval_dataset=tokenized_ds['val'],
    tokenizer=tokenizer,
    data_collator=functools.partial(collate_fn, tokenizer=tokenizer),
    compute_metrics=compute_metrics,
    label_weights=torch.tensor(label_weights, device=model.device),
)
```

```python
trainer.train()
```
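For completeness, `CustomTrainer` is a `Trainer` subclass that consumes the `label_weights` argument in the loss. A minimal sketch of that pattern (the exact loss shown here, a weighted multilabel BCE, is illustrative):

```python
import torch
from transformers import Trainer

class CustomTrainer(Trainer):
    """Trainer with per-label loss weights for multilabel classification (sketch)."""

    def __init__(self, label_weights=None, **kwargs):
        super().__init__(**kwargs)
        self.label_weights = label_weights

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        # weighted multilabel BCE: each label's term is scaled by its weight
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            outputs.logits,
            labels.to(outputs.logits.dtype),
            weight=self.label_weights,
        )
        return (loss, outputs) if return_outputs else loss
```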

I have tried many of the fixes discussed on Google, but nothing has worked. Any help would be appreciated.