Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

I am fine-tuning a pre-trained ESM-2 model from Hugging Face with the QLoRA technique.
Here is my encoder:

    import torch
    from transformers import BitsAndBytesConfig, EsmModel
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

    def __init__(self, model_name):
        super().__init__()
        # QLoRA fine-tuning: load the base model in 4-bit NF4 with double quantization
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16
        )
        self.model = EsmModel.from_pretrained(model_name, quantization_config=quantization_config)
        self.model = prepare_model_for_kbit_training(self.model,
                                                     use_gradient_checkpointing=False)

        # attach LoRA adapters to the attention projections and dense layers
        config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["query", "key", "value", "dense"],
            lora_dropout=0.05,
            bias="none",
        )
        self.model = get_peft_model(self.model, config)
        ...

    def forward(self, x):
        # x["sequence"] holds the tokenized inputs (input_ids, attention_mask, ...)
        x_sequence = {key: value for key, value in x["sequence"].items()}
        features = self.model(**x_sequence)
        ...
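
For reference, the `sequence` entry of the batch is just the tokenizer output, roughly like this sketch (assuming `AutoTokenizer` for the same checkpoint; my actual collate code is longer):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # "sequence" ends up as a dict of tensors: input_ids, attention_mask
    sequence = tokenizer(["MKTAYIAKQRQISFVK", "MLSDEDFKAVFGMTR"],
                         padding=True, return_tensors="pt")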

Here is my decoder's forward method:

    def forward(self, encoder_out, target_input):
        # causal mask and padding mask for the target sequence
        tgt_mask, tgt_padding_mask = create_mask(target_input, self.pad_idx, self.device)
        tgt_embedding = self.embedding(target_input)
        tgt_embedding = self.decoder_pos_drop(tgt_embedding + self.decoder_pos_embed)

        encoder_out = self.encoder_pos_drop(encoder_out + self.encoder_pos_embed)

        # the decoder expects (seq_len, batch, dim), hence the transposes
        encoder_out = encoder_out.transpose(0, 1)
        tgt_embedding = tgt_embedding.transpose(0, 1)

        preds = self.decoder(memory=encoder_out,
                             tgt=tgt_embedding,
                             tgt_mask=tgt_mask,
                             tgt_key_padding_mask=tgt_padding_mask)
        preds = preds.transpose(0, 1)
        return self.output(preds)
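
(`create_mask` builds the usual causal and padding masks; a sketch of what mine does, assuming a batch-first `target_input`:)

    def create_mask(target_input, pad_idx, device):
        seq_len = target_input.shape[1]
        # additive causal mask: -inf above the diagonal blocks attention to future positions
        tgt_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1
        )
        # True where the target token is padding
        tgt_padding_mask = (target_input == pad_idx)
        return tgt_mask, tgt_padding_mask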

This is my training loop:

    accelerator = Accelerator(
        mixed_precision='fp16',
        gradient_accumulation_steps=8
    )

    tools['net'], tools['optimizer'], tools['train_loader'], tools['scheduler'] = accelerator.prepare(
        net, optimizer, dataloaders_dict["train"], scheduler
    )

    for i, data in enumerate(tools['train_loader']):
        with accelerator.accumulate(tools['net']):
            embeddings, task_num, sequence, target = data

            # teacher forcing: shift the target one position
            target_input = target[:, :-1]
            target_expected = target[:, 1:]

            batch = {"sequence": sequence, "embedding": embeddings, "target_input": target_input}

            preds = tools['net'](batch)
            loss = tools['loss_function'](preds.reshape(-1, preds.shape[-1]), target_expected.reshape(-1))
            loss = torch.mean(loss)

            # gather per-process losses for logging
            avg_loss = accelerator.gather(loss.repeat(tools["train_batch_size"])).mean()
            train_loss += avg_loss.item() / tools['accum_iter']

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(tools['net'].parameters(), tools['grad_clip'])

            tools['optimizer'].step()
            tools['scheduler'].step()
            tools['optimizer'].zero_grad()
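
(For completeness, `loss_function` is token-level cross-entropy with `reduction="none"`, which is why `torch.mean` follows; a sketch, with `pad_idx` assumed to be my padding index:)

    tools['loss_function'] = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, reduction="none")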

I connected an autoregressive decoder to the encoder to build a seq2seq model. The code works well on a single GPU, but when I set the accelerate config to use two GPUs, I get this error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

This is my error log:

File "/home/mpngf/projects/JointTraining/train.py", line 90, in train
    preds = tools['net'](batch, mode=0)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/accelerate/utils/operations.py", line 632, in forward
    return model_forward(*args, **kwargs)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/accelerate/utils/operations.py", line 620, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/mpngf/projects/JointTraining/model.py", line 192, in forward
    encoder_out = self.encoder(batch)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mpngf/projects/JointTraining/model.py", line 89, in forward
    features = self.model(**x["sequence"])
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/peft/peft_model.py", line 322, in forward
    return self.get_base_model()(*args, **kwargs)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/transformers/models/esm/modeling_esm.py", line 917, in forward
    embedding_output = self.embeddings(
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/transformers/models/esm/modeling_esm.py", line 203, in forward
    inputs_embeds = self.word_embeddings(input_ids)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/home/mpngf/environments/joint_training/lib/python3.10/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

It seems to be related to this line:

    features = self.model(**x_sequence)
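
For anyone debugging the same thing, the devices can be printed right before that call (a quick sketch; I assume `get_input_embeddings` resolves through the PEFT wrapper to the ESM word embeddings the traceback points at):

    # inside the encoder forward, right before self.model(**x_sequence)
    for key, value in x_sequence.items():
        print(key, value.device)                       # device of each input tensor
    print("word_embeddings:", self.model.get_input_embeddings().weight.device)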

I use accelerate for training and have tested my code on two servers (2x A5000 and 2x Titan RTX); both fail with the same error.
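
For reference, the run is launched with the standard multi-GPU command (the actual settings come from `accelerate config`, but equivalent to):

    accelerate launch --multi_gpu --num_processes 2 train.py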

I would greatly appreciate any help getting this code working across multiple GPUs.

I removed the decoder part, and the error still exists; it again points to this line:

    features = self.model(**x_sequence)

Does anybody know what the problem is?