Traceback (most recent call last):
File "/home/user/xutian/gpt2-sentiment-trl.py", line 153, in <module>
answer = get_answer(question)
^^^^^^^^^^^^^^^^^^^^
File "/home/user/xutian/gpt2-sentiment-trl.py", line 95, in get_answer
answer = generate(torch.stack(question))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/xutian/gpt2-sentiment-trl.py", line 81, in generate
return model_ppo.generate(input_ids=input_ids,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/trl/models/modeling_value_head.py", line 202, in generate
return self.pretrained_model.generate(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/transformers/generation/utils.py", line 1764, in generate
return self.sample(
^^^^^^^^^^^^
File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/transformers/generation/utils.py", line 2861, in sample
outputs = self(
^^^^^
File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1074, in forward
transformer_outputs = self.transformer(
^^^^^^^^^^^^^^^^^
File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 837, in forward
inputs_embeds = self.wte(input_ids)
^^^^^^^^^^^^^^^^^^^
File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 162, in forward
return F.embedding(
^^^^^^^^^^^^
File "/home/user/anaconda3/envs/xt_lab/lib/python3.11/site-packages/torch/nn/functional.py", line 2233, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:3! (when checking argument for argument index in method wrapper_CUDA__index_select)
I put model_ppo and the input_ids on cuda:3 before model_ppo.generate runs, so why does the traceback still report a device conflict between cuda:0 and cuda:3?
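For reference, this is how I would check the placement right before the failing call (a diagnostic sketch, not part of the script; generate_debug is a hypothetical helper, and the wte attribute path assumes TRL's GPT-2 value-head wrapper):

def generate_debug(input_ids):
    # print where the inputs and the embedding weights actually live
    print('input_ids on:', input_ids.device)
    print('model params on:', next(model_ppo.parameters()).device)
    # wte is the embedding layer named at the bottom of the traceback
    print('wte weight on:', model_ppo.pretrained_model.transformer.wte.weight.device)
    return model_ppo.generate(input_ids=input_ids, max_new_tokens=32,
                              pad_token_id=tokenizer.pad_token_id)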
Here is the full script:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import random
import torch
from datasets import load_dataset, concatenate_datasets, load_from_disk
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

device = 'cuda:3' if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained('tokenizer/lvwerra/gpt2-imdb')
tokenizer.pad_token_id = 0

dataset = load_from_disk('dataset/imdb')
dataset = concatenate_datasets([dataset[i] for i in ['train', 'test']])
def f(data):
    question = tokenizer.encode(data['text'], add_special_tokens=False)[:5]
    return {'question': question}

dataset = dataset.map(f, remove_columns=['label', 'text'])

def f(data):
    return len(data['question']) == 5

dataset = dataset.filter(f)
def get_batch_data():
    label = random.choices(range(2), k=128)
    question = random.choices(dataset, k=128)
    question = [i['question'] for i in question]
    # prepend the sampled sentiment label (as a token id) to each question
    question = [[tokenizer.convert_tokens_to_ids(str(l))] + q
                for l, q in zip(label, question)]
    return label, question

get_batch_data()
model_ppo = AutoModelForCausalLMWithValueHead.from_pretrained(
    'model/lvwerra/gpt2-imdb').to(device)
model_ppo_ref = AutoModelForCausalLMWithValueHead.from_pretrained(
    'model/lvwerra/gpt2-imdb').to(device)
for i in model_ppo_ref.parameters():
    i.requires_grad_(False)

tokenizer_cls = AutoTokenizer.from_pretrained(
    'tokenizer/lvwerra/distilbert-imdb')
model_cls = AutoModelForSequenceClassification.from_pretrained(
    'model/lvwerra/distilbert-imdb').to(device)
for i in model_cls.parameters():
    i.requires_grad_(False)
def get_question():
    label, question = get_batch_data()
    label = torch.LongTensor(label).to(device)
    question = [torch.LongTensor(i).to(device) for i in question]
    return label, question

label, question = get_question()
def generate(input_ids):
    # the return below is the frame "line 81, in generate" in the traceback
    return model_ppo.generate(input_ids=input_ids,
                              min_length=-1,
                              top_k=0.0,
                              top_p=1.0,
                              do_sample=True,
                              pad_token_id=tokenizer.pad_token_id,
                              max_new_tokens=32,
                              eos_token_id=tokenizer.eos_token_id)
def get_answer(question):
    if True:
        # batched generation; this is the frame "line 95, in get_answer"
        answer = generate(torch.stack(question))
        answer_new = []
        for i in answer:
            # truncate each sequence at the first eos token, if any
            if tokenizer.eos_token_id not in i:
                answer_new.append(i.unsqueeze(0))
                continue
            split = i.tolist().index(tokenizer.eos_token_id) + 1
            answer_new.append(i[:split].unsqueeze(0))
        answer = answer_new
    else:
        # alternative path: generate one question at a time
        answer = [generate(i.unsqueeze(0)) for i in question]
        answer = [a[0, len(q):] for q, a in zip(question, answer)]
    return answer

answer = get_answer(question)
def get_reward(question, answer, label):
    # score question+answer text with the sentiment classifier
    token = [q.tolist()[1:] + a.tolist() for q, a in zip(question, answer)]
    token = [tokenizer.decode(i) for i in token]
    token = tokenizer_cls(token,
                          padding=True,
                          truncation=True,
                          max_length=512,
                          return_tensors='pt').to(device)
    with torch.no_grad():
        logits = model_cls(**token).logits
    return logits.gather(1, label.reshape(-1, 1)).squeeze(1)

reward = get_reward(question, answer, label)
config = PPOConfig(learning_rate=1e-5, batch_size=128)
trainer = PPOTrainer(config,
                     model_ppo,
                     model_ppo_ref,
                     tokenizer,
                     dataset=dataset)
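# note: the traceback's "line 153" appears to point at the get_answer call
# inside the training loop below, i.e. the error only shows up after this
# PPOTrainer has been constructed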
import warnings
warnings.filterwarnings('ignore')
for epoch in range(200):
    label, question = get_question()
    answer = get_answer(question)
    reward = get_reward(question, answer, label)
    trainer.step(question, answer, [i for i in reward])

    print(epoch, reward.mean().item())
    question = tokenizer.decode(question[0].tolist())
    answer = tokenizer.decode(answer[0].tolist())
    reward = reward[0].item()
    print(question, '->', answer, reward)

model_ppo.save_pretrained("gpt2-ppo-model")
tokenizer.save_pretrained("gpt2-ppo-tokenizer")
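A workaround I could try is to make generate defensive and move the inputs to whatever device the weights are actually on at call time (just a sketch, assuming the whole model sits on a single device):

def generate_safe(input_ids):
    # send inputs to the device the weights currently occupy, in case
    # something moved the model after it was first placed on cuda:3
    model_device = next(model_ppo.parameters()).device
    return model_ppo.generate(input_ids=input_ids.to(model_device),
                              min_length=-1,
                              top_k=0.0,
                              top_p=1.0,
                              do_sample=True,
                              pad_token_id=tokenizer.pad_token_id,
                              max_new_tokens=32,
                              eos_token_id=tokenizer.eos_token_id)

But I would still like to understand why two devices show up here at all when I explicitly moved everything to cuda:3.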