Hi everyone,
I have an unbalanced dataset, so I’ve been trying to extend the Trainer class as described in the Trainer docs* and in this thread.
So I have the following code:
import torch
import transformers
from torch import nn
from transformers import DataCollatorWithPadding, Trainer, TrainingArguments

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get('logits')
        # compute custom loss
        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([0.2, 0.3, 0.5]))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
I then define the trainer and try to train it in the following way:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
tokenized_train = tokenized_train.with_format("torch")
tokenized_val = tokenized_val.with_format("torch")
model.to('cuda')

training_args = TrainingArguments(
    output_dir="my_awesome_model",
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    #push_to_hub=True,
)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()
As you can see, I also set the dataset format to ‘torch’ as suggested in another thread*, but I still get the following error at trainer.train():
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument weight in method wrapper_CUDA_nll_loss_forward)
I’ve already tried updating the transformers library as suggested here, but it didn’t work. When I run this code with the plain Trainer class it works correctly; the error only appears when I use this simple custom class.
What could be the cause of this error and how could I fix it?
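One thing the error message hints at: it names the weight argument of the NLL loss, so the class-weight tensor may be sitting on the CPU while the logits are on cuda:0. Below is a minimal sketch of that idea, not a confirmed fix. The helper name weighted_cross_entropy is hypothetical (it is not part of transformers), and the sample logits and labels are made up so the snippet can run on CPU.

```python
import torch
from torch import nn


# Hypothetical helper, not a transformers API: builds the weighted loss with
# the class weights created on the same device (and dtype) as the logits.
def weighted_cross_entropy(logits, labels, class_weights=(0.2, 0.3, 0.5)):
    weight = torch.tensor(class_weights, dtype=logits.dtype, device=logits.device)
    loss_fct = nn.CrossEntropyLoss(weight=weight)
    num_labels = logits.size(-1)
    return loss_fct(logits.view(-1, num_labels), labels.view(-1))


# Made-up inputs for a quick check; on a GPU run the logits would be on cuda:0
# and the weight tensor would follow them there.
logits = torch.randn(4, 3)
labels = torch.tensor([0, 1, 2, 1])
loss = weighted_cross_entropy(logits, labels)
```

Inside compute_loss, the two loss lines could then become loss = weighted_cross_entropy(outputs.get("logits"), labels), assuming the number of class weights matches the model's num_labels.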
*I would have added links to these but as a new user I can add only two links, sorry for the inconvenience.