RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

I’m trying to fine-tune Llama-2-13b-chat for machine translation with lightning trainer and deepspeed. The test reported here is on 4 rtx 3090 with an 8bit quantized version of the model and LoRa and I get this error. This is the code used for training




trainer = L.Trainer(
    devices=4,
    accelerator="gpu",
    #strategy="ddp",
    strategy=DeepSpeedStrategy(
         stage=3,
         offload_optimizer=True,
         offload_parameters=True,
     ),
    max_steps=10,
    #precision="16-mixed",
    enable_checkpointing=True,
    accumulate_grad_batches=args.accumulate_grad_batches,
    log_every_n_steps=20,
    val_check_interval=30,
    limit_val_batches=200,
    #default_root_dir="checkpoints",
    callbacks=[checkpoint_callback],
    gradient_clip_val=args.gradient_clip_val,
    logger=mlflow_logger,
)
print(args.quantized)
print(quantization_config)
model = MTModel(model_name=args.model_name, pad_token_id=tokenizer.pad_token_id, inference=False,
                 learning_rate=args.learning_rate, weight_decay=args.weight_decay, 
                 betas=args.betas, quantization_config=quantization_config,peft=args.peft, load_in_8bit=False)

trainer.fit(model, train_data_loader, test_data_loader)

this is MTModel


quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_compute_dtype=torch.float16,
    bnb_8bit_quant_type="nf8",
    bnb_8bit_use_double_quant=True,
)

class MTModel(L.LightningModule):
    def __init__(self, model_name: str, pad_token_id: int,inference: bool, 
                 learning_rate: float = 1e-4, weight_decay: float = 0.0, 
                 betas: tuple = (0.9, 0.95), quantization_config:LoraConfig = None,
                 load_in_8bit:bool=False,peft:bool=False):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = AutoModelForCausalLM.from_pretrained(model_name,trust_remote_code=True,device_map="auto",
                                                          quantization_config=quantization_config,load_in_8bit=load_in_8bit)

        if peft:
            peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=inference, r=4, lora_alpha=4, lora_dropout=0.01)
            self.model = get_peft_model(self.model, peft_config)
            
        

    def training_step(self, batch, batch_idx):
        get_accelerator().empty_cache()
        input_ids, target_start_idx = batch
        logits = self.model(input_ids).logits
        loss = mt_loss(logits, input_ids, target_start_idx, self.hparams.pad_token_id)
        self.log("train_loss", loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        input_ids, target_start_idx = batch
        logits = self.model(input_ids).logits
        loss = mt_loss(logits, input_ids, target_start_idx, self.hparams.pad_token_id)
        self.log("val_loss", loss, sync_dist=True)
        self.log("val_ppl", torch.exp(loss), sync_dist=True)
        return loss

    def generate(self, batch, **kwargs):
        return self.model.generate(batch,**kwargs)
    
    def configure_optimizers(self):
        return DeepSpeedCPUAdam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay, betas=self.hparams.betas)

and here is the error

                                  β”‚
β”‚                                                                                                  β”‚
β”‚ /home/michele/miniconda3/lib/python3.10/site-packages/torch/nn/functional.py:2235 in embedding   β”‚
β”‚                                                                                                  β”‚
β”‚   2232 β”‚   β”‚   #   torch.embedding_renorm_                                                       β”‚
β”‚   2233 β”‚   β”‚   # remove once script supports set_grad_enabled                                    β”‚
β”‚   2234 β”‚   β”‚   _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)                    β”‚
β”‚ ❱ 2235 β”‚   return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)        β”‚
β”‚   2236                                                                                           β”‚
β”‚   2237                                                                                           β”‚
β”‚   2238 def embedding_bag(                                                                        β”‚
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

I found one similar similar issue but I get the error also without load_in_8bit=True and anyway in my case the model dispatch should be handled by DeepSpeed. Has anyone encountered similar errors?

Try this:

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(DEVICE)
2 Likes

I got the exact same error with my model. Also using deepspeed and bitsandbytes. Did you manage to solve the issue?
My code runs fine on a single gpu, but fails on multiple. (using A100)

Did you figure out a workaround for this issue?