Got "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!" on my custom model

I am trying to add a new linear layer to Qwen 2.5 and use lora to make SFT. When I use it locally with small max_length, it works fine. But it reports error when using 2 GPUs. The error message is as follow:

Traceback (most recent call last):
  File "/home/csun/project/information-retrival/extract_token/train_extract_token_all_loss.py", line 244, in <module>
    trainer.train()
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/transformers/trainer.py", line 2171, in train
    return inner_training_loop(
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/transformers/trainer.py", line 2531, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/transformers/trainer.py", line 3675, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "/home/csun/project/information-retrival/extract_token/train_extract_token_all_loss.py", line 171, in compute_loss
    outputs = model(
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 193, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 212, in parallel_apply
    return parallel_apply(
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 126, in parallel_apply
    output.reraise()
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/_utils.py", line 733, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
    output = module(*input, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/csun/project/information-retrival/extract_token/CustomLM.py", line 68, in forward
    outputs = self.base_model(
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/peft/peft_model.py", line 1719, in forward
    return self.base_model(
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
    return self.model.forward(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 816, in forward
    outputs = self.model(
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 551, in forward
    position_embeddings = self.rotary_emb(hidden_states, position_ids)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/creat/anaconda3/envs/information-retrival/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 331, in forward
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)

and here is my custome model code:

import torch

class PositionSmoother(torch.nn.Module):
    def __init__(self, tau=0.1):
        super().__init__()
        self.tau = tau

    def forward(self, logits):
        """
        修正后的前向传播逻辑
        输入形状: [batch_size, 2, seq_length]
        输出形状: (start_positions [batch_size], end_positions [batch_size])
        """
        # 分离起始和结束位置的logits
        start_logits = logits[:, 0, :]  # [batch_size, seq_length]
        end_logits = logits[:, 1, :]  # [batch_size, seq_length]

        # 定义采样函数
        def sample_position(pos_logits):
            # 使用Gumbel-Softmax获取采样结果
            sampled = torch.nn.functional.gumbel_softmax(
                pos_logits,
                tau=self.tau,
                hard=True
            )  # [batch_size, seq_length]

            # 生成位置索引
            positions = torch.arange(
                pos_logits.size(1),  # 使用seq_length维度
                device=pos_logits.device
            ).float()  # [seq_length]

            # 计算加权位置值
            return torch.sum(sampled * positions, dim=1)  # [batch_size]

        # 分别处理起始和结束位置
        start_pos = sample_position(start_logits)
        end_pos = sample_position(end_logits)

        return start_pos, end_pos


class CustomLM(torch.nn.Module):
    def __init__(self, base_model, max_seq_length, device):
        super().__init__()
        self.device = device
        self.base_model = base_model.to(self.device)
        self.max_seq_length = max_seq_length

        # 修正后的位置预测头
        self.position_head = torch.nn.Linear(
            base_model.config.hidden_size,
            2 * max_seq_length  # 输出起始和结束位置分布
        )
        self.position_smoother = PositionSmoother(tau=0.1)
        # 添加位置预测头初始化控制
        torch.nn.init.xavier_normal_(self.position_head.weight, gain=0.1)
        torch.nn.init.constant_(self.position_head.bias, 0.0)

    def forward(self, input_ids, attention_mask, labels=None):
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        labels = labels.to(self.device)
        print("Model parameter device:", self.base_model.device)
        print("Input tensor device:", input_ids)
        print("attention_mask tensor device:", attention_mask)
        print("labels tensor device:", labels)
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True
        )

        # 获取隐藏状态并调整形状
        last_hidden = outputs.hidden_states[-1]  # [batch_size, seq_len, hidden_size]

        # 使用平均池化获取序列特征
        pooled = last_hidden.mean(dim=1)  # [batch_size, hidden_size]

        # 生成位置logits
        position_logits = self.position_head(pooled)  # [batch_size, 2*max_seq_length]

        # 调整形状为 [batch_size, 2, max_seq_length]
        position_logits = position_logits.view(-1, 2, self.max_seq_length)

        # 应用位置平滑
        start_pos, end_pos = self.position_smoother(position_logits)

        return {
            "lm_loss": outputs.loss,
            "start_position": start_pos,
            "end_position": end_pos,
            "hidden_states": outputs.hidden_states
        }

and training code is as:

import os
import torch
from accelerate import dispatch_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback
from CustomLM import CustomLM, PositionSmoother
import torch.nn.functional as F
from cal_embedding_by_start_end import cal_embedding
from torch.nn import DataParallel

# 1. 加载已下载的 LLaMA-8B 模型
# model_name = "/home/admin01/projects/lab/llm/Meta-Llama-3-8B-Instruct/rag/Llama-3.1-8B"  # 替换为本地 LLaMA-8B 模型路径
# model_name = "/home/admin01/projects/lab/llm/Qwen2.5-72B-Instruct"
# model_name = "/data/csun/llama-70b"  # 替换为本地 LLaMA-8B 模型路径
# model_name = "/home/admin01/projects/lab/llm/Qwen2.5-7B-Instruct"  # 替换为本地 Qwen 路径
model_name = "/data/csun/Qwen2.5-7B-Instruct"  # 替换为本地 LLaMA-8B 模型路径
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
MAX_LENGTH = 3000


from transformers import BitsAndBytesConfig

# bnb_config = BitsAndBytesConfig(
#     load_in_8bit=True,  # 启用4-bit量化
#     llm_int8_has_fp16_weight=True
# )
# # 确保你有 GPU 可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2',
    # quantization_config=bnb_config  # 指定BitsAndBytesConfig
).to(device)


#
# # 将模型移动到 GPU 或 CPU 上
# model.to(device)
# 2. 准备训练数据
# 使用 Hugging Face Datasets 或加载自定义 JSON 数据
# dataset = load_dataset("json", data_files="/home/admin01/projects/lab/information-retrival/data/finetuning_data.json")  # 替换为你的数据路径
dataset = load_dataset("json", data_files={"train":"train_finetuning_data.json",
                                           "validation":"valid_finetuning_data.json",
                                           "test":"test_finetuning_data.json"})  # 替换为你的数据路径

# 数据预处理函数
def preprocess_function(example):
    # 拼接 prompt 与 input,构造完整的上下文
    # full_text = example["instruction"]+ "\n\n" + example["input"]+ "\n-开始token位置:" + str(example["start"])+ "\n-结束token位置:" + str(example["end"])
    # 拼接 prompt 和输入部分
    add_len = len(tokenizer.tokenize(example["instruction"] + "\n\n**Input**:"))
    prompt_text = example["instruction"]+ "\n\n" + example["input"]
    target_text = "\n-开始token位置:" + str(example["start"]+add_len)+ "\n-结束token位置:" + str(example["end"]+add_len)

    # 1. 先单独对 target_text 分词,计算其长度(不添加特殊 token)
    target_encoded = tokenizer(target_text, add_special_tokens=False, truncation=False)
    target_length = len(target_encoded["input_ids"])

    # 2. 计算 prompt 允许的最大长度(总长度 - target长度 - 特殊 token)
    max_length = MAX_LENGTH
    # 假设模型需要添加 BOS/EOS(如 GPT),则占 2 个 token
    special_tokens_count = 1
    prompt_max_length = max_length - target_length - special_tokens_count
    # prompt_max_length = max_length - target_length

    # 3. 对 prompt_text 单独分词并截断(不添加特殊 token)
    prompt_encoded = tokenizer(
        prompt_text,
        truncation=True,
        max_length=prompt_max_length,
        add_special_tokens=False
    )
    prompt_ids = prompt_encoded["input_ids"]

    # 4. 手动拼接完整序列(添加 BOS + prompt + target + EOS)
    input_ids = prompt_ids + target_encoded["input_ids"] + [tokenizer.eos_token_id]

    # 5. 处理填充和截断
    if len(input_ids) > max_length:
        # 如果仍然超长(因特殊 token 或计算误差),截断 prompt 部分
        input_ids = input_ids[:max_length - 1] + [tokenizer.eos_token_id]
    attention_mask = [1] * len(input_ids)

    # 填充到 max_length
    if len(input_ids) < max_length:
        pad_length = max_length - len(input_ids)
        input_ids += [tokenizer.pad_token_id] * pad_length
        attention_mask += [0] * pad_length

    # 6. 计算 target 起始位置(BOS + prompt 长度)
    target_position = len(prompt_ids)

    # 7. 构造 labels(仅 target 部分有效)
    labels = [-100] * target_position + input_ids[target_position:]

    # output ids
    # output_ids = tokenizer.encode(example["output"])
    start_label = [example["start"]+add_len]
    end_label = [example["end"]+add_len]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "target_position": target_position,
        # "output_ids": output_ids,
        "start_label": start_label,
        "end_label": end_label,
        "prompt_max_length": prompt_max_length
    }

tokenized_dataset = {}
tokenized_dataset["train"] = dataset["train"].map(preprocess_function, batched=False).filter(lambda x:x["end_label"][0]<x["prompt_max_length"])
tokenized_dataset["validation"] = dataset["validation"].map(preprocess_function, batched=False).filter(lambda x:x["end_label"][0]<x["prompt_max_length"])
tokenized_dataset["test"] = dataset["test"].map(preprocess_function, batched=False).filter(lambda x:x["end_label"][0]<x["prompt_max_length"])

# 3. 配置 QLoRA 参数
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # 因果语言建模任务
    inference_mode=False,
    r=8,  # LoRA 矩阵秩
    lora_alpha=32,  # 缩放因子
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05  # Dropout 概率
)

# 将 LoRA 应用于模型
model = get_peft_model(model, lora_config).to(device)
model = CustomLM(model, max_seq_length=MAX_LENGTH, device=device).to(device)
# 4. 设置训练参数
training_args = TrainingArguments(
    output_dir="/data/csun/model/qwen7b_extract_token",
    # output_dir="../model/qwen7b_extract_token",
    num_train_epochs=3,
    # per_device_train_batch_size=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=1,
    learning_rate=1e-4,
    warmup_steps=100,
    bf16=True,
    logging_dir="./qwen7b_extract_token",
    logging_steps=10,
    save_total_limit=2,
    evaluation_strategy="steps",  # 每个epoch评估一次
    eval_steps=100,
    save_strategy="steps",       # 每个epoch保存一次模型
    save_steps=100,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    load_best_model_at_end=True,
    report_to="tensorboard"
)


class CustomTrainer(Trainer):
    def __init__(self, *args, max_teacher_forcing_steps=5000, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_teacher_forcing_steps = max_teacher_forcing_steps
        # 初始化损失函数
        self.cos_loss = torch.nn.CosineEmbeddingLoss()
        self._signature_columns = ['input_ids', 'attention_mask', 'labels', 'label', 'labels', 'label_ids', 'output_ids',
                              'start_label', 'end_label']

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):

        # 前向传播
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            labels=inputs.get("labels")  # 更安全的获取方式
        )
        input_len = inputs['input_ids'].shape[1]
        # ========== 关键修改部分 ==========
        # 1. 位置损失计算(自动处理batch平均)
        # pos_loss = (F.mse_loss(outputs["start_position"], inputs["start_label"], reduction='sum') + F.mse_loss(
        #     outputs["end_position"], inputs["end_label"], reduction='sum')) / num_items_in_batch
        '''1、计算start和end各自的位置损失'''
        pos_loss = F.mse_loss(outputs["start_position"]/input_len, inputs["start_label"].squeeze(1)/input_len, reduction='mean') + F.mse_loss(
            outputs["end_position"]/input_len, inputs["end_label"].squeeze(1)/input_len, reduction='mean')
        '''2、如果start在end之后,加入位置惩罚'''
        hidden_states = outputs["hidden_states"][-1]
        batch_size = hidden_states.size(0)
        '''如果训练轮次小于指定轮次,直接使用标记位置提取hidden_state,防止预测位置偏差太大'''
        if self.state.global_step < self.max_teacher_forcing_steps:
            starts = inputs['start_label']
            ends = inputs['end_label']
        else:
            # 获取预测位置(需确保可微分)
            starts = outputs["start_position"].unsqueeze(1)  # 转换为整数索引
            ends = outputs["end_position"].unsqueeze(1)
        pos_mask = (starts < ends).long()
        neg_loss = F.mse_loss(starts/input_len, ends/input_len, reduction='mean') * (1-pos_mask)
        '''3、如果start在end之前,提取hidden_states,计算余弦相似度损失'''
        # 边界保护
        seq_len = hidden_states.size(1)
        abs_starts = min(starts, ends)
        abs_ends = max(starts, ends)
        # abs_starts = torch.clamp(abs_starts, 0, seq_len - 1).squeeze(1)
        # abs_ends = torch.clamp(abs_ends, 0, seq_len).squeeze(1)

        # # 向量化提取特征
        # input_embs = []
        # for i in range(batch_size):
        #     span_emb = hidden_states[i, starts[i]:ends[i]].mean(dim=0)
        #     input_embs.append(span_emb)
        # input_embs = torch.stack(input_embs)  # [batch_size, hidden_size]
        input_embs = cal_embedding(abs_starts, abs_ends, seq_len, hidden_states, self.args.device)
        output_embs = cal_embedding(inputs['start_label'], inputs['end_label'], seq_len, hidden_states, self.args.device)

        # 计算余弦相似度损失(自动平均)
        sim_loss = self.cos_loss(
            input_embs,
            output_embs,
            torch.ones(batch_size, device=self.args.device)  # 正样本标签
        )
        sim_loss = sim_loss * pos_mask
        # 3. 总损失计算(不需要手动除batch_size)
        total_loss = (
                outputs["lm_loss"] +  # 语言模型损失(已自动平均)
                10*pos_loss + # 开始和结束位置损失(已平均)
                10*neg_loss +  # 开始在结束后的额外惩罚
                10*sim_loss  # 相似度损失(已平均)
        )
        outputs["total_loss"] = total_loss
        return (total_loss, outputs) if return_outputs else total_loss

# 假设模型和数据集已准备好
trainer = CustomTrainer(
    model=model,  # 您的预训练模型
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    max_teacher_forcing_steps=1000
    # compute_metrics=compute_metrics  # 使用自定义评估函数
)

# 开始训练
trainer.train()
# 保存微调后的模型
model.save_pretrained("/data/csun/model/qwen7b_extract_token_best")
tokenizer.save_pretrained("/data/csun/model/qwen7b_extract_token_best")

I have tried methods mentioned in above topics but all won’t work. Anyone can help?

1 Like

Remove the device_map="auto" argument in model initialization and try.

2 Likes

And perhaps a bug in qwen code.

1 Like

This topic was automatically closed 12 hours after the last reply. New replies are no longer allowed.