RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

I’m trying to repurpose the Phi3 model for sentiment analysis (specifically, regression). The Phi3 model in the transformers package has a Phi3ForSequenceClassification class with a regression option, but as far as I know there are no pretrained weights for it. My plan is to load the pretrained weights of the Phi3ForCausalLM class into a Phi3ForSequenceClassification model, train the additional linear layer, and then finetune the whole model. I’m running the script below on a remote server using SLURM.
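
For context, a minimal unquantized sketch of what I mean by loading the pretrained weights into the classification class would be something like the following (from_pretrained should reuse the backbone weights from the causal-LM checkpoint and only initialize the new score head randomly; I haven’t verified this end to end):

import torch
from transformers.models.phi3.modeling_phi3 import Phi3ForSequenceClassification

# Sketch only: no quantization, no LoRA, single device.
seq_model = Phi3ForSequenceClassification.from_pretrained(
    './microsoft/Phi-3-mini-4k-instruct',
    num_labels = 1,
    problem_type = 'regression',
    torch_dtype = torch.float16
)
# The shared decoder weights come from the checkpoint; only seq_model.score starts from scratch.

The full script (quantized, multi-GPU) is below: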

from transformers import DataCollatorWithPadding, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, AutoTokenizer, Trainer
from transformers.models.phi3.modeling_phi3 import Phi3ForSequenceClassification
from peft import get_peft_model, LoraConfig, TaskType
from torch.utils.data import Dataset
import torch

# Load the pretrained causal-LM backbone in 8-bit; device_map = 'auto' lets accelerate shard it across all visible GPUs
transformer = AutoModelForCausalLM.from_pretrained(
    './microsoft/Phi-3-mini-4k-instruct',
    use_cache = True,
    trust_remote_code = True,
    attn_implementation = 'eager',
    torch_dtype = torch.float16,
    quantization_config = BitsAndBytesConfig(
        load_in_8bit = True,
        bnb_4bit_compute_dtype = torch.float16
    ),
    device_map = 'auto'
)

# Switch the config to single-output regression
transformer.config.num_labels = 1
transformer.config.problem_type = 'regression'
print(transformer.config)

# Build a sequence-classification model from the same config, swap in the pretrained decoder,
# and cast the newly initialized score head (which is on the CPU at this point) to float16
model = Phi3ForSequenceClassification(transformer.config)
model.model = transformer.get_decoder()
model.score = model.score.to(torch.float16)

# Print which device each parameter ended up on (the result is pasted in the output section below)
for n, m in model.named_parameters():
    try:
        print(f'{n} -> {m.device}')
    except Exception as e:
        print(f'{n} -> NO DEVICE ({e})')

model.train()

peftConfig = LoraConfig(
    r = 16,
    lora_alpha = 32,
    lora_dropout = 0.05,
    bias = 'none',
    task_type = TaskType.SEQ_CLS,
    # target every module whose class name is exactly 'Linear'
    target_modules = [n for n, m in model.named_modules() if type(m).__name__ == 'Linear']
)

model = get_peft_model(model, peftConfig)
model.print_trainable_parameters()

tokenizer = AutoTokenizer.from_pretrained('./microsoft/Phi-3-mini-4k-instruct')
tokenizer.model_max_length = 1024
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
tokenizer.padding_side = 'right'

# Dummy train/test data: identical texts with all-zero labels, just to exercise the training loop
train_texts = [
    'The capital of France is Paris. Paris is located in the northern part of France and is known for its iconic landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral. It is also famous for its art, culture, fashion, and cuisine.' for _ in range(1_000)
] 
train_labels = [0 for _ in range(1_000)]

test_texts = [
    'The capital of France is Paris. Paris is located in the northern part of France and is known for its iconic landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral. It is also famous for its art, culture, fashion, and cuisine.' for _ in range(100)
] 
test_labels = [0 for _ in range(100)]

train_encodings = tokenizer(train_texts, truncation = True, padding = 'max_length', max_length = 1024)
test_encodings = tokenizer(test_texts, truncation = True, padding = 'max_length', max_length = 1024)

class TextDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # input_ids stay integer; the other encodings (e.g. attention_mask) and the labels are cast to float16
        item = {
            key: torch.tensor(val[idx]).long() if key == 'input_ids'
            else torch.tensor(val[idx]).to(torch.float16)
            for key, val in self.encodings.items()
        }
        item['labels'] = torch.tensor(self.labels[idx]).to(torch.float16)
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = TextDataset(train_encodings, train_labels)
test_dataset = TextDataset(test_encodings, test_labels)

dataCollator = DataCollatorWithPadding(tokenizer = tokenizer)

trainingArgs = TrainingArguments(
    output_dir = 'checkpoint',
    learning_rate = 2e-5,
    per_device_train_batch_size = 2,
    per_device_eval_batch_size = 2,
    num_train_epochs = 2,
    weight_decay = 0.01,
    eval_strategy = 'epoch',
    save_strategy = 'epoch',
    load_best_model_at_end = True,
    push_to_hub = False,
    fp16 = True
)

trainer = Trainer(
    model = model,
    args = trainingArgs,
    train_dataset = train_dataset,
    eval_dataset = test_dataset,
    tokenizer = tokenizer,
    data_collator = dataCollator
)

trainer.train()
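
In case it is useful, this is how I would inspect the layer-to-device mapping that device_map = 'auto' picked (not part of the script above; hf_device_map is set when the model is loaded with a device map):

# Sketch: print the module-to-device mapping chosen by accelerate for the quantized backbone
for module_name, device in transformer.hf_device_map.items():
    print(f'{module_name} -> {device}')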

Here is the SLURM file I’m using:

#!/bin/bash
#SBATCH --output=%j.out
#SBATCH --error=%j.err
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=4
#SBATCH --cpus-per-task=10
#SBATCH --time=1:00:00
#SBATCH --mem=64G
#SBATCH --mail-user=MY_EMAIL@EMAIL.EMAIL
#SBATCH --mail-type=ALL

module load python

ENVDIR=$(mktemp -d)

python -m venv $ENVDIR
source $ENVDIR/bin/activate

pip install --no-index --upgrade pip
pip install --no-index bitsandbytes peft transformers torch matplotlib numpy

export MASTER_ADDR=$(hostname -s)
export MASTER_PORT=$(shuf -i 10000-65500 -n 1)
export WORLD_SIZE=$SLURM_NTASKS
export RANK=$SLURM_PROCID

srun python -u del.py

deactivate
rm -rf $ENVDIR

And here is the error file:

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attenton` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards:  50%|█████     | 1/2 [00:02<00:02,  2.84s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.09s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.20s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

  0%|          | 0/250 [00:00<?, ?it/s]You are not running the flash-attention implementation, expect numerical differences.
Traceback (most recent call last):
  File "/project/6004929/MY_USERNAME/del.py", line 111, in <module>
    trainer.train()
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/transformers/trainer.py", line 1932, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/transformers/trainer.py", line 2268, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/transformers/trainer.py", line 3307, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/transformers/trainer.py", line 3338, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 108, in parallel_apply
    output.reraise()
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/_utils.py", line 705, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/peft/peft_model.py", line 1238, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 179, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/transformers/models/phi3/modeling_phi3.py", line 1510, in forward
    model_outputs = self.model(
                    ^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/MY_USERNAME/.cache/huggingface/modules/transformers_modules/June6_Model/modeling_phi3.py", line 1164, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/MY_USERNAME/.cache/huggingface/modules/transformers_modules/June6_Model/modeling_phi3.py", line 882, in forward
    hidden_states = self.input_layernorm(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmp.7T5FEiUvq9/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/MY_USERNAME/.cache/huggingface/modules/transformers_modules/June6_Model/modeling_phi3.py", line 95, in forward
    return self.weight * hidden_states.to(input_dtype)
           ~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

srun: error: cdr902: task 0: Exited with exit code 1
srun: Terminating StepId=35961212.0

And here is the output file:

Phi3Config {
  "_name_or_path": "./June6_Model",
  "architectures": [
    "Phi3ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "auto_map": {
    "AutoConfig": "configuration_phi3.Phi3Config",
    "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM"
  },
  "bos_token_id": 1,
  "embd_pdrop": 0.0,
  "eos_token_id": 32000,
  "hidden_act": "silu",
  "hidden_size": 3072,
  "id2label": {
    "0": "LABEL_0"
  },
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "label2id": {
    "LABEL_0": 0
  },
  "max_position_embeddings": 131072,
  "model_type": "phi3",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "original_max_position_embeddings": 4096,
  "pad_token_id": 32000,
  "problem_type": "regression",
  "quantization_config": {
    "_load_in_4bit": false,
    "_load_in_8bit": true,
    "bnb_4bit_compute_dtype": "float16",
    "bnb_4bit_quant_storage": "uint8",
    "bnb_4bit_quant_type": "fp4",
    "bnb_4bit_use_double_quant": false,
    "llm_int8_enable_fp32_cpu_offload": false,
    "llm_int8_has_fp16_weight": false,
    "llm_int8_skip_modules": null,
    "llm_int8_threshold": 6.0,
    "load_in_4bit": false,
    "load_in_8bit": true,
    "quant_method": "bitsandbytes"
  },
  "resid_pdrop": 0.0,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "long_factor": [
      1.0299999713897705,
      1.0499999523162842,
      1.0499999523162842,
      1.0799999237060547,
      1.2299998998641968,
      1.2299998998641968,
      1.2999999523162842,
      1.4499999284744263,
      1.5999999046325684,
      1.6499998569488525,
      1.8999998569488525,
      2.859999895095825,
      3.68999981880188,
      5.419999599456787,
      5.489999771118164,
      5.489999771118164,
      9.09000015258789,
      11.579999923706055,
      15.65999984741211,
      15.769999504089355,
      15.789999961853027,
      18.360000610351562,
      21.989999771118164,
      23.079999923706055,
      30.009998321533203,
      32.35000228881836,
      32.590003967285156,
      35.56000518798828,
      39.95000457763672,
      53.840003967285156,
      56.20000457763672,
      57.95000457763672,
      59.29000473022461,
      59.77000427246094,
      59.920005798339844,
      61.190006256103516,
      61.96000671386719,
      62.50000762939453,
      63.3700065612793,
      63.48000717163086,
      63.48000717163086,
      63.66000747680664,
      63.850006103515625,
      64.08000946044922,
      64.760009765625,
      64.80001068115234,
      64.81001281738281,
      64.81001281738281
    ],
    "short_factor": [
      1.05,
      1.05,
      1.05,
      1.1,
      1.1,
      1.1500000000000001,
      1.2000000000000002,
      1.2500000000000002,
      1.3000000000000003,
      1.3500000000000003,
      1.5000000000000004,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.000000000000001,
      2.0500000000000007,
      2.0500000000000007,
      2.0500000000000007,
      2.1000000000000005,
      2.1000000000000005,
      2.1000000000000005,
      2.1500000000000004,
      2.1500000000000004,
      2.3499999999999996,
      2.549999999999999,
      2.5999999999999988,
      2.5999999999999988,
      2.7499999999999982,
      2.849999999999998,
      2.849999999999998,
      2.9499999999999975
    ],
    "type": "su"
  },
  "rope_theta": 10000.0,
  "sliding_window": 262144,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.42.2",
  "use_cache": true,
  "vocab_size": 32064
}

model.embed_tokens.weight -> cuda:0
model.layers.0.self_attn.o_proj.weight -> cuda:0
model.layers.0.self_attn.qkv_proj.weight -> cuda:0
model.layers.0.mlp.gate_up_proj.weight -> cuda:0
model.layers.0.mlp.down_proj.weight -> cuda:0
model.layers.0.input_layernorm.weight -> cuda:0
model.layers.0.post_attention_layernorm.weight -> cuda:0
model.layers.1.self_attn.o_proj.weight -> cuda:0
model.layers.1.self_attn.qkv_proj.weight -> cuda:0
model.layers.1.mlp.gate_up_proj.weight -> cuda:0
model.layers.1.mlp.down_proj.weight -> cuda:0
model.layers.1.input_layernorm.weight -> cuda:0
model.layers.1.post_attention_layernorm.weight -> cuda:0
model.layers.2.self_attn.o_proj.weight -> cuda:0
model.layers.2.self_attn.qkv_proj.weight -> cuda:0
model.layers.2.mlp.gate_up_proj.weight -> cuda:0
model.layers.2.mlp.down_proj.weight -> cuda:0
model.layers.2.input_layernorm.weight -> cuda:0
model.layers.2.post_attention_layernorm.weight -> cuda:0
model.layers.3.self_attn.o_proj.weight -> cuda:0
model.layers.3.self_attn.qkv_proj.weight -> cuda:0
model.layers.3.mlp.gate_up_proj.weight -> cuda:0
model.layers.3.mlp.down_proj.weight -> cuda:0
model.layers.3.input_layernorm.weight -> cuda:0
model.layers.3.post_attention_layernorm.weight -> cuda:0
model.layers.4.self_attn.o_proj.weight -> cuda:0
model.layers.4.self_attn.qkv_proj.weight -> cuda:0
model.layers.4.mlp.gate_up_proj.weight -> cuda:0
model.layers.4.mlp.down_proj.weight -> cuda:0
model.layers.4.input_layernorm.weight -> cuda:0
model.layers.4.post_attention_layernorm.weight -> cuda:0
model.layers.5.self_attn.o_proj.weight -> cuda:0
model.layers.5.self_attn.qkv_proj.weight -> cuda:0
model.layers.5.mlp.gate_up_proj.weight -> cuda:0
model.layers.5.mlp.down_proj.weight -> cuda:0
model.layers.5.input_layernorm.weight -> cuda:0
model.layers.5.post_attention_layernorm.weight -> cuda:0
model.layers.6.self_attn.o_proj.weight -> cuda:1
model.layers.6.self_attn.qkv_proj.weight -> cuda:1
model.layers.6.mlp.gate_up_proj.weight -> cuda:1
model.layers.6.mlp.down_proj.weight -> cuda:1
model.layers.6.input_layernorm.weight -> cuda:1
model.layers.6.post_attention_layernorm.weight -> cuda:1
model.layers.7.self_attn.o_proj.weight -> cuda:1
model.layers.7.self_attn.qkv_proj.weight -> cuda:1
model.layers.7.mlp.gate_up_proj.weight -> cuda:1
model.layers.7.mlp.down_proj.weight -> cuda:1
model.layers.7.input_layernorm.weight -> cuda:1
model.layers.7.post_attention_layernorm.weight -> cuda:1
model.layers.8.self_attn.o_proj.weight -> cuda:1
model.layers.8.self_attn.qkv_proj.weight -> cuda:1
model.layers.8.mlp.gate_up_proj.weight -> cuda:1
model.layers.8.mlp.down_proj.weight -> cuda:1
model.layers.8.input_layernorm.weight -> cuda:1
model.layers.8.post_attention_layernorm.weight -> cuda:1
model.layers.9.self_attn.o_proj.weight -> cuda:1
model.layers.9.self_attn.qkv_proj.weight -> cuda:1
model.layers.9.mlp.gate_up_proj.weight -> cuda:1
model.layers.9.mlp.down_proj.weight -> cuda:1
model.layers.9.input_layernorm.weight -> cuda:1
model.layers.9.post_attention_layernorm.weight -> cuda:1
model.layers.10.self_attn.o_proj.weight -> cuda:1
model.layers.10.self_attn.qkv_proj.weight -> cuda:1
model.layers.10.mlp.gate_up_proj.weight -> cuda:1
model.layers.10.mlp.down_proj.weight -> cuda:1
model.layers.10.input_layernorm.weight -> cuda:1
model.layers.10.post_attention_layernorm.weight -> cuda:1
model.layers.11.self_attn.o_proj.weight -> cuda:1
model.layers.11.self_attn.qkv_proj.weight -> cuda:1
model.layers.11.mlp.gate_up_proj.weight -> cuda:1
model.layers.11.mlp.down_proj.weight -> cuda:1
model.layers.11.input_layernorm.weight -> cuda:1
model.layers.11.post_attention_layernorm.weight -> cuda:1
model.layers.12.self_attn.o_proj.weight -> cuda:1
model.layers.12.self_attn.qkv_proj.weight -> cuda:1
model.layers.12.mlp.gate_up_proj.weight -> cuda:1
model.layers.12.mlp.down_proj.weight -> cuda:1
model.layers.12.input_layernorm.weight -> cuda:1
model.layers.12.post_attention_layernorm.weight -> cuda:1
model.layers.13.self_attn.o_proj.weight -> cuda:1
model.layers.13.self_attn.qkv_proj.weight -> cuda:1
model.layers.13.mlp.gate_up_proj.weight -> cuda:1
model.layers.13.mlp.down_proj.weight -> cuda:1
model.layers.13.input_layernorm.weight -> cuda:1
model.layers.13.post_attention_layernorm.weight -> cuda:1
model.layers.14.self_attn.o_proj.weight -> cuda:2
model.layers.14.self_attn.qkv_proj.weight -> cuda:2
model.layers.14.mlp.gate_up_proj.weight -> cuda:2
model.layers.14.mlp.down_proj.weight -> cuda:2
model.layers.14.input_layernorm.weight -> cuda:2
model.layers.14.post_attention_layernorm.weight -> cuda:2
model.layers.15.self_attn.o_proj.weight -> cuda:2
model.layers.15.self_attn.qkv_proj.weight -> cuda:2
model.layers.15.mlp.gate_up_proj.weight -> cuda:2
model.layers.15.mlp.down_proj.weight -> cuda:2
model.layers.15.input_layernorm.weight -> cuda:2
model.layers.15.post_attention_layernorm.weight -> cuda:2
model.layers.16.self_attn.o_proj.weight -> cuda:2
model.layers.16.self_attn.qkv_proj.weight -> cuda:2
model.layers.16.mlp.gate_up_proj.weight -> cuda:2
model.layers.16.mlp.down_proj.weight -> cuda:2
model.layers.16.input_layernorm.weight -> cuda:2
model.layers.16.post_attention_layernorm.weight -> cuda:2
model.layers.17.self_attn.o_proj.weight -> cuda:2
model.layers.17.self_attn.qkv_proj.weight -> cuda:2
model.layers.17.mlp.gate_up_proj.weight -> cuda:2
model.layers.17.mlp.down_proj.weight -> cuda:2
model.layers.17.input_layernorm.weight -> cuda:2
model.layers.17.post_attention_layernorm.weight -> cuda:2
model.layers.18.self_attn.o_proj.weight -> cuda:2
model.layers.18.self_attn.qkv_proj.weight -> cuda:2
model.layers.18.mlp.gate_up_proj.weight -> cuda:2
model.layers.18.mlp.down_proj.weight -> cuda:2
model.layers.18.input_layernorm.weight -> cuda:2
model.layers.18.post_attention_layernorm.weight -> cuda:2
model.layers.19.self_attn.o_proj.weight -> cuda:2
model.layers.19.self_attn.qkv_proj.weight -> cuda:2
model.layers.19.mlp.gate_up_proj.weight -> cuda:2
model.layers.19.mlp.down_proj.weight -> cuda:2
model.layers.19.input_layernorm.weight -> cuda:2
model.layers.19.post_attention_layernorm.weight -> cuda:2
model.layers.20.self_attn.o_proj.weight -> cuda:2
model.layers.20.self_attn.qkv_proj.weight -> cuda:2
model.layers.20.mlp.gate_up_proj.weight -> cuda:2
model.layers.20.mlp.down_proj.weight -> cuda:2
model.layers.20.input_layernorm.weight -> cuda:2
model.layers.20.post_attention_layernorm.weight -> cuda:2
model.layers.21.self_attn.o_proj.weight -> cuda:2
model.layers.21.self_attn.qkv_proj.weight -> cuda:2
model.layers.21.mlp.gate_up_proj.weight -> cuda:2
model.layers.21.mlp.down_proj.weight -> cuda:2
model.layers.21.input_layernorm.weight -> cuda:2
model.layers.21.post_attention_layernorm.weight -> cuda:2
model.layers.22.self_attn.o_proj.weight -> cuda:3
model.layers.22.self_attn.qkv_proj.weight -> cuda:3
model.layers.22.mlp.gate_up_proj.weight -> cuda:3
model.layers.22.mlp.down_proj.weight -> cuda:3
model.layers.22.input_layernorm.weight -> cuda:3
model.layers.22.post_attention_layernorm.weight -> cuda:3
model.layers.23.self_attn.o_proj.weight -> cuda:3
model.layers.23.self_attn.qkv_proj.weight -> cuda:3
model.layers.23.mlp.gate_up_proj.weight -> cuda:3
model.layers.23.mlp.down_proj.weight -> cuda:3
model.layers.23.input_layernorm.weight -> cuda:3
model.layers.23.post_attention_layernorm.weight -> cuda:3
model.layers.24.self_attn.o_proj.weight -> cuda:3
model.layers.24.self_attn.qkv_proj.weight -> cuda:3
model.layers.24.mlp.gate_up_proj.weight -> cuda:3
model.layers.24.mlp.down_proj.weight -> cuda:3
model.layers.24.input_layernorm.weight -> cuda:3
model.layers.24.post_attention_layernorm.weight -> cuda:3
model.layers.25.self_attn.o_proj.weight -> cuda:3
model.layers.25.self_attn.qkv_proj.weight -> cuda:3
model.layers.25.mlp.gate_up_proj.weight -> cuda:3
model.layers.25.mlp.down_proj.weight -> cuda:3
model.layers.25.input_layernorm.weight -> cuda:3
model.layers.25.post_attention_layernorm.weight -> cuda:3
model.layers.26.self_attn.o_proj.weight -> cuda:3
model.layers.26.self_attn.qkv_proj.weight -> cuda:3
model.layers.26.mlp.gate_up_proj.weight -> cuda:3
model.layers.26.mlp.down_proj.weight -> cuda:3
model.layers.26.input_layernorm.weight -> cuda:3
model.layers.26.post_attention_layernorm.weight -> cuda:3
model.layers.27.self_attn.o_proj.weight -> cuda:3
model.layers.27.self_attn.qkv_proj.weight -> cuda:3
model.layers.27.mlp.gate_up_proj.weight -> cuda:3
model.layers.27.mlp.down_proj.weight -> cuda:3
model.layers.27.input_layernorm.weight -> cuda:3
model.layers.27.post_attention_layernorm.weight -> cuda:3
model.layers.28.self_attn.o_proj.weight -> cuda:3
model.layers.28.self_attn.qkv_proj.weight -> cuda:3
model.layers.28.mlp.gate_up_proj.weight -> cuda:3
model.layers.28.mlp.down_proj.weight -> cuda:3
model.layers.28.input_layernorm.weight -> cuda:3
model.layers.28.post_attention_layernorm.weight -> cuda:3
model.layers.29.self_attn.o_proj.weight -> cuda:3
model.layers.29.self_attn.qkv_proj.weight -> cuda:3
model.layers.29.mlp.gate_up_proj.weight -> cuda:3
model.layers.29.mlp.down_proj.weight -> cuda:3
model.layers.29.input_layernorm.weight -> cuda:3
model.layers.29.post_attention_layernorm.weight -> cuda:3
model.layers.30.self_attn.o_proj.weight -> cuda:3
model.layers.30.self_attn.qkv_proj.weight -> cuda:3
model.layers.30.mlp.gate_up_proj.weight -> cuda:3
model.layers.30.mlp.down_proj.weight -> cuda:3
model.layers.30.input_layernorm.weight -> cuda:3
model.layers.30.post_attention_layernorm.weight -> cuda:3
model.layers.31.self_attn.o_proj.weight -> cuda:3
model.layers.31.self_attn.qkv_proj.weight -> cuda:3
model.layers.31.mlp.gate_up_proj.weight -> cuda:3
model.layers.31.mlp.down_proj.weight -> cuda:3
model.layers.31.input_layernorm.weight -> cuda:3
model.layers.31.post_attention_layernorm.weight -> cuda:3
model.norm.weight -> cuda:3
score.weight -> cpu
trainable params: 52,240 || all params: 3,722,683,424 || trainable%: 0.0014

I’ve looked at similar posts on this forum, but none of the solutions worked for me; I keep getting the “expected tensors to be on the same device” error. Any and all help would be greatly appreciated. I’ve been stuck on this error for days.