LoRA Finetuning RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

Hello everyone,
I am trying to fine-tune the Llama 3.1 8B Instruct model using LoRA. I would like to use multiple GPUs, but I am getting the following error.

Traceback (most recent call last):                                                                                                                               
  File "/home/user/s25/finetune_model_LoRA.py", line 68, in <module>                                                                      
    trainer.train()                                                                                                                                              
    ~~~~~~~~~~~~~^^                                                                                                                                              
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/transformers/trainer.py", line 2240, in train                       
    return inner_training_loop(                                                                                                                                  
        args=args,                                                                                                                                               
    ...<2 lines>...                                                                                                                                              
        ignore_keys_for_eval=ignore_keys_for_eval,                                                                                                               
    )                                                                                                                                                            
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/transformers/trainer.py", line 2555, in _inner_training_loop        
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)                                                                                         
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/trl/trainer/sft_trainer.py", line 733, in training_step             
    return super().training_step(*args, **kwargs)                                                                                                                
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^                                                                                                                
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/transformers/trainer.py", line 3745, in training_step               
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)                                                                               
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/trl/trainer/sft_trainer.py", line 687, in compute_loss              
    (loss, outputs) = super().compute_loss(                                                                                                                      
                      ~~~~~~~~~~~~~~~~~~~~^                                                                                                                      
        model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch                                                                                
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                
    )                                                                                                                                                            
    ^                                                                                                                                                            
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/transformers/trainer.py", line 3810, in compute_loss                
    outputs = model(**inputs)   
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl       
    return self._call_impl(*args, **kwargs)                                                                                                                      
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^                                                                                                                      
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl               
    return forward_call(*args, **kwargs)                                                                                                                         
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/accelerate/utils/operations.py", line 818, in forward               
    return model_forward(*args, **kwargs)                                                                                                                        
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/accelerate/utils/operations.py", line 806, in __call__              
    return convert_to_fp32(self.model_forward(*args, **kwargs))                                                                                                  
                           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^                                                                                                   
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast          
    return func(*args, **kwargs)                                                                                                                                 
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/peft/peft_model.py", line 1757, in forward                          
    return self.base_model(                                                                                                                                      
           ~~~~~~~~~~~~~~~^                                                                                                                                      
        input_ids=input_ids,                                                                                                                                     
        ^^^^^^^^^^^^^^^^^^^^                                                                                                                                     
    ...<6 lines>...                                                                                                                                              
        **kwargs,                                                                                                                                                
        ^^^^^^^^^                                                                                                                                                
    )                                                                                                                                                            
    ^                                                                                                                                                            
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl       
    return self._call_impl(*args, **kwargs)                                                                                                                      
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^                                                                                                                      
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl               
    return forward_call(*args, **kwargs)                                                                                                                         
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/peft/tuners/tuners_utils.py", line 193, in forward                  
    return self.model.forward(*args, **kwargs)                                                                                                                   
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^                                                                                                                   
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/accelerate/hooks.py", line 175, in new_forward                      
    output = module._old_forward(*args, **kwargs)                                                                                                                
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/transformers/utils/generic.py", line 969, in wrapper
    output = func(self, *args, **kwargs)   
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/transformers/models/llama/modeling_llama.py", line 708, in forward
    loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/transformers/loss/loss_utils.py", line 64, in ForCausalLMLoss
    loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
  File "/local/home/user/miniforge3/envs/project/lib/python3.13/site-packages/transformers/loss/loss_utils.py", line 38, in fixed_cross_entropy
    loss = loss / num_items_in_batch                                            
           ~~~~~^~~~~~~~~~~~~~~~~~~~                                            
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
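
As far as I can tell, this is the generic PyTorch cross-device error: with device_map="auto" the model layers are sharded across both GPUs, and by the time loss_utils.py divides the loss by num_items_in_batch, the two tensors apparently sit on different devices. A minimal illustration of the same error (the values here are made up, just for the example):

import torch

# Two tensors on different GPUs, mimicking what seems to happen in
# fixed_cross_entropy: the loss on one device, num_items_in_batch on the other
loss = torch.tensor(1.25, device="cuda:1")
num_items_in_batch = torch.tensor(4, device="cuda:0")

# Raises: RuntimeError: Expected all tensors to be on the same device, ...
loss = loss / num_items_in_batch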

I am using the following script:

# Imports
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, DataCollatorForLanguageModeling, BitsAndBytesConfig
from peft import LoraConfig
from huggingface_hub import login
from datasets import load_dataset
from dotenv import load_dotenv
from trl import SFTTrainer, SFTConfig
from os import getenv
import torch

# Load environment variables
load_dotenv()

# Login to huggingface
login(token=getenv("HUGGINGFACE_ACCESS_TOKEN"))

# Load bitsandbytes config
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                bnb_4bit_compute_dtype="float16", bnb_4bit_use_double_quant=False)

# Load the model and tokenizer corresponding to the model
model_name = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Load the dataset
dataset = load_dataset(
    "json", data_files="/home/user/s25/documents.jsonl", split="train")

# Define tokenization function and tokenize the dataset


def tokenize(examples):
    inputs = tokenizer(examples["document"])
    return inputs


tokenized_dataset = dataset.map(
    tokenize, batched=True, remove_columns=["document"])

# Instantiate data collator
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Load LoRA configuration
peft_config = LoraConfig(
    r=64, lora_alpha=16, lora_dropout=0, task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"])

# Specify the training arguments
trainings_arguments = SFTConfig(
    output_dir="/data/projects/s25/Llama-3.1-8B-Instruct-lora-v6-1epochs",
    save_strategy="steps", save_steps=500, save_total_limit=1,
    per_device_train_batch_size=2, num_train_epochs=1,
    learning_rate=5e-4, weight_decay=0.01,
    logging_dir="/data/projects/s25/Llama-3.1-8B-Instruct-lora-v6-1epochs-log",
    logging_steps=50, report_to="none", fp16=True, bf16=False,
    dataset_text_field=None)

# Set up trainer
trainer = SFTTrainer(model=model, args=trainings_arguments,
                     train_dataset=tokenized_dataset, processing_class=tokenizer, data_collator=data_collator, peft_config=peft_config)

# Train the model
trainer.train()

This issue is very similar to the following existing posts:

However, the solutions provided there did not help me solve the problem.

Lastly, here are the versions of the most relevant packages (not necessarily everything needed to run the script, but I was character-limited for this post).

accelerate                1.7.0              pyhe01879c_0    conda-forge
bitsandbytes              0.46.0          cuda126_py313hde49398_0    conda-forge
datasets                  3.6.0              pyhd8ed1ab_0    conda-forge
huggingface_hub           0.33.0             pyhd8ed1ab_0    conda-forge
numpy                     2.3.0           py313h17eae1a_0    conda-forge
pandas                    2.3.0           py313ha87cce1_0    conda-forge
pip                       25.1.1             pyh145f28c_0    conda-forge
python                    3.13.2          hf636f53_101_cp313    conda-forge
python-dateutil           2.9.0.post0        pyhff2d567_1    conda-forge
python-dotenv             1.1.0              pyh29332c3_1    conda-forge
python-gil                3.13.5             h4df99d1_101    conda-forge
python-tzdata             2025.2             pyhd8ed1ab_0    conda-forge
python-xxhash             3.5.0           py313h536fd9c_2    conda-forge
python_abi                3.13                    7_cp313    conda-forge
pytorch                   2.7.0           cuda126_generic_py313_h14c909a_200    conda-forge
tokenizers                0.21.1          py313h1191936_0    conda-forge
torch                     2.6.0+cu126              pypi_0    pypi
torchaudio                2.6.0+cu126              pypi_0    pypi
torchvision               0.21.0+cu126             pypi_0    pypi
transformers              4.52.4             pyhd8ed1ab_0    conda-forge
trl                       0.18.2             pyhd8ed1ab_0    conda-forge

I am very grateful for any support! Thank you very much!


Perhaps it is an unresolved compatibility issue between accelerate and bitsandbytes?

Thanks for the information. However, I have tried running the script without the bitsandbytes configuration (and also with the bitsandbytes package removed), just using multiple GPUs, and the error persists.

So, essentially, I am simply loading the model as follows:

model_name = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

(And, by the way, I am launching the script with: CUDA_VISIBLE_DEVICES=0,1 python finetune_model_LoRA.py)
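
For context, a quick way to see whether the model really is split across both GPUs is to look at the placement that accelerate records when device_map="auto" is used (just a sketch; the commented output is an assumption on my side):

# accelerate records where each submodule was placed; with two visible GPUs
# some layers land on cuda:0 and the rest (typically including lm_head) on cuda:1
print(model.hf_device_map)
print({str(d) for d in model.hf_device_map.values()})  # e.g. {'0', '1'}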


UPDATE: At least for now, the problem seems to be fixed. I downgraded the transformers library to version 4.49.0, used transformers.Trainer instead of the SFTTrainer, and modified the loading of the model as follows.

# Additional imports needed for the PEFT preparation below
from peft import prepare_model_for_kbit_training, get_peft_model

# Load bitsandbytes config
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                bnb_4bit_compute_dtype="float16", bnb_4bit_use_double_quant=False)

# Load LoRA configuration
peft_config = LoraConfig(
    r=64, lora_alpha=16, lora_dropout=0, task_type="CAUSAL_LM", target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])

# Load the model and prepare it for peft finetuning
model_name = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, device_map="auto")

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
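
The rest of the setup stayed essentially the same as in the original script. Roughly, the trainer part now looks like this (just a sketch; the argument values simply mirror my original SFTConfig and may need adjusting):

from transformers import Trainer, TrainingArguments

# Plain transformers Trainer instead of trl's SFTTrainer; the values
# below mirror the SFTConfig from the original script
training_arguments = TrainingArguments(
    output_dir="/data/projects/s25/Llama-3.1-8B-Instruct-lora-v6-1epochs",
    save_strategy="steps", save_steps=500, save_total_limit=1,
    per_device_train_batch_size=2, num_train_epochs=1,
    learning_rate=5e-4, weight_decay=0.01,
    logging_dir="/data/projects/s25/Llama-3.1-8B-Instruct-lora-v6-1epochs-log",
    logging_steps=50, report_to="none", fp16=True)

trainer = Trainer(model=model, args=training_arguments,
                  train_dataset=tokenized_dataset,
                  data_collator=data_collator)
trainer.train()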

Maybe this will help someone in the future!

