Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:3! (when checking argument for argument index in method wrapper_CUDA__index_select)

Traceback (most recent call last):
  File "/home/user/xutian/gpt2-sentiment-trl.py", line 153, in <module>
    answer = get_answer(question)
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/user/xutian/gpt2-sentiment-trl.py", line 95, in get_answer
    answer = generate(torch.stack(question))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/xutian/gpt2-sentiment-trl.py", line 81, in generate
    return model_ppo.generate(input_ids=input_ids,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/trl/models/modeling_value_head.py", line 202, in generate
    return self.pretrained_model.generate(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/transformers/generation/utils.py", line 1764, in generate
    return self.sample(
           ^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/transformers/generation/utils.py", line 2861, in sample
    outputs = self(
              ^^^^^
  File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1074, in forward
    transformer_outputs = self.transformer(
                          ^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 837, in forward
    inputs_embeds = self.wte(input_ids)
                    ^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/functional.py", line 2233, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:3! (when checking argument for argument index in method wrapper_CUDA__index_select)

I have put model_ppo and the input_ids on cuda:3 before model_ppo.generate runs, so why does the error still report a device conflict between cuda:0 and cuda:3?
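For what it's worth, here is the kind of sanity check I would expect to pass right before the failing call (a hypothetical diagnostic, not part of the script below):

# Hypothetical diagnostic: confirm where the weights and the inputs actually live
print(next(model_ppo.parameters()).device)                   # I expect cuda:3
print(next(model_ppo.pretrained_model.parameters()).device)  # the wrapped GPT-2
print(torch.stack(question).device)                          # I expect cuda:3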

Here is the code:

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import random
import torch
from datasets import concatenate_datasets, load_from_disk
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer



device = 'cuda:3' if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained('tokenizer/lvwerra/gpt2-imdb')
tokenizer.pad_token_id = 0  # GPT-2 has no pad token by default



# load IMDB from disk and merge the train and test splits
dataset = load_from_disk('dataset/imdb')
dataset = concatenate_datasets([dataset[i] for i in ['train', 'test']])


def take_first_tokens(data):
    # keep only the first 5 tokens of each review as the prompt
    question = tokenizer.encode(data['text'], add_special_tokens=False)[:5]
    return {'question': question}


dataset = dataset.map(take_first_tokens, remove_columns=['label', 'text'])


def has_five_tokens(data):
    # drop reviews that are shorter than 5 tokens
    return len(data['question']) == 5


dataset = dataset.filter(has_five_tokens)

def get_batch_data():
    # random sentiment targets: 0 = negative, 1 = positive
    label = random.choices(range(2), k=128)

    # sample 128 prompts from the dataset
    question = random.choices(dataset, k=128)
    question = [i['question'] for i in question]

    # prepend the target label as a control token ('0' or '1') to each prompt
    question = [[tokenizer.convert_tokens_to_ids(str(l))] + q
                for l, q in zip(label, question)]

    return label, question



# policy model (trained with PPO) and a frozen reference copy, both on cuda:3
model_ppo = AutoModelForCausalLMWithValueHead.from_pretrained(
    'model/lvwerra/gpt2-imdb').to(device)
model_ppo_ref = AutoModelForCausalLMWithValueHead.from_pretrained(
    'model/lvwerra/gpt2-imdb').to(device)

for i in model_ppo_ref.parameters():
    i.requires_grad_(False)


# frozen sentiment classifier used as the reward model
tokenizer_cls = AutoTokenizer.from_pretrained(
    'tokenizer/lvwerra/distilbert-imdb')
model_cls = AutoModelForSequenceClassification.from_pretrained(
    'model/lvwerra/distilbert-imdb').to(device)

for i in model_cls.parameters():
    i.requires_grad_(False)

def get_question():
    # move the labels and the prompt tensors to cuda:3
    label, question = get_batch_data()
    label = torch.LongTensor(label).to(device)

    question = [torch.LongTensor(i).to(device) for i in question]

    return label, question


label, question = get_question()



def generate(input_ids):
    # sample up to 32 new tokens from the policy model
    return model_ppo.generate(input_ids=input_ids,
                              min_length=-1,
                              top_k=0,
                              top_p=1.0,
                              do_sample=True,
                              pad_token_id=tokenizer.pad_token_id,
                              max_new_tokens=32,
                              eos_token_id=tokenizer.eos_token_id)


def get_answer(question):
    # generate for the whole batch at once; the else branch is the
    # one-prompt-at-a-time alternative
    batched = True
    if batched:
        answer = generate(torch.stack(question))

        # truncate each sample just after its first EOS token
        answer_new = []
        for i in answer:
            if tokenizer.eos_token_id not in i:
                answer_new.append(i.unsqueeze(0))
                continue

            split = i.tolist().index(tokenizer.eos_token_id) + 1

            answer_new.append(i[:split].unsqueeze(0))
        answer = answer_new
    else:
        answer = [generate(i.unsqueeze(0)) for i in question]

    # strip the prompt, keeping only the newly generated tokens
    answer = [a[0, len(q):] for q, a in zip(question, answer)]

    return answer


answer = get_answer(question)


def get_reward(question, answer, label):
    # drop the control token, then decode prompt + answer for the classifier
    token = [q.tolist()[1:] + a.tolist() for q, a in zip(question, answer)]
    token = [tokenizer.decode(i) for i in token]

    token = tokenizer_cls(token,
                          padding=True,
                          truncation=True,
                          max_length=512,
                          return_tensors='pt').to(device)

    with torch.no_grad():
        logits = model_cls(**token).logits

    # reward = classifier logit of the target sentiment class
    return logits.gather(1, label.reshape(-1, 1)).squeeze(1)


reward = get_reward(question, answer, label)


config = PPOConfig(learning_rate=1e-5, batch_size=128)

trainer = PPOTrainer(config,
                     model_ppo,
                     model_ppo_ref,
                     tokenizer,
                     dataset=dataset)

import warnings

warnings.filterwarnings('ignore')

for epoch in range(200):
    label, question = get_question()
    answer = get_answer(question)
    reward = get_reward(question, answer, label)

    # PPO step takes lists of query tensors, response tensors, and rewards
    trainer.step(question, answer, list(reward))
    print(epoch, reward.mean().item())

    # decode one sample for inspection
    question = tokenizer.decode(question[0].tolist())
    answer = tokenizer.decode(answer[0].tolist())
    reward = reward[0].item()

    print(question, '->', answer, reward)

model_ppo.save_pretrained("gpt2-ppo-model")
tokenizer.save_pretrained("gpt2-ppo-tokenizer")
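
One thing I am starting to suspect: PPOTrainer wraps the model with accelerate internally, so constructing the trainer might move model_ppo off cuda:3 (to cuda:0, the default device). A minimal check I would put around the trainer construction (hypothetical, not in the script above):

# Hypothetical check: does constructing PPOTrainer relocate model_ppo?
print('before PPOTrainer:', next(model_ppo.parameters()).device)  # cuda:3?
trainer = PPOTrainer(config, model_ppo, model_ppo_ref, tokenizer, dataset=dataset)
print('after PPOTrainer:', next(model_ppo.parameters()).device)   # still cuda:3?

If that is the cause, what is the right way to keep everything on cuda:3?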