I have an unbalanced dataset, so I’ve been trying to extend the Trainer class as described in the Trainer docs* and in this thread.
So I have the following code:

    import torch
    import transformers
    from torch import nn
    from transformers import Trainer

    class CustomTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False):
            labels = inputs.get("labels")
            # forward pass
            outputs = model(**inputs)
            logits = outputs.get('logits')
            # compute custom loss
            loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([0.2, 0.3, 0.5]))
            loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
            return (loss, outputs) if return_outputs else loss

I then define the trainer and try to train it in the following way:

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    tokenized_train = tokenized_train.with_format("torch")
    tokenized_val = tokenized_val.with_format("torch")

    model.to('cuda')

    training_args = TrainingArguments(
        output_dir="my_awesome_model",
        learning_rate=1e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=5,
        weight_decay=0.01,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        #push_to_hub=True,
    )

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    trainer.train()

As you can see, I also set the dataset format to ‘torch’ as suggested in another thread*, but I still get the following error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument weight in method wrapper_CUDA_nll_loss_forward)
I’ve already tried updating the transformers library as suggested here, but it didn’t help. When I run this code with the plain Trainer class it works correctly; the error appears only when I use this simple custom class.
What could be the cause of this error and how could I fix it?
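The error message names the `weight` argument, so one possible cause is that the class-weight tensor built with `torch.tensor([0.2, 0.3, 0.5])` stays on the CPU while the logits live on `cuda:0`. Below is a minimal sketch under that assumption; `weighted_cross_entropy` is a hypothetical helper name (not part of transformers or PyTorch), and the random logits and labels are made-up stand-ins for a real batch.

```python
import torch
from torch import nn


def weighted_cross_entropy(logits, labels, class_weights):
    """Weighted cross-entropy with the weights placed on the logits' device.

    Hypothetical helper for illustration only.
    """
    num_labels = logits.size(-1)
    # A tensor created from a Python list lives on the CPU by default,
    # so create it directly on the same device (and dtype) as the logits.
    weight = torch.as_tensor(class_weights, dtype=logits.dtype, device=logits.device)
    loss_fct = nn.CrossEntropyLoss(weight=weight)
    return loss_fct(logits.view(-1, num_labels), labels.view(-1))


# Stand-in batch: 4 examples, 3 classes (made-up data).
device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(4, 3, device=device)
labels = torch.tensor([0, 1, 2, 2], device=device)

loss = weighted_cross_entropy(logits, labels, [0.2, 0.3, 0.5])
```

If this is indeed the cause, the same idea would apply inside `compute_loss`: pass `logits.device` (or call `.to(logits.device)`) when building the weight tensor, rather than relying on `model.to('cuda')` to move it, since the weights are not part of the model.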
*I would have added links to these but as a new user I can add only two links, sorry for the inconvenience.