Using class weights with Hugging Face Transformers - running on GPUs

I came across this tutorial, which performs text classification with the Longformer. I also came across these two links (one and two), which discuss using class weights when the data is unbalanced.
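For context, class weights for unbalanced data are often derived from the label counts. Below is a minimal sketch, assuming the common "balanced" heuristic (weight = n_samples / (n_classes * class_count)). The helper name and the toy labels are hypothetical, and this is not necessarily how the weights used later in this post were obtained.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Hypothetical helper: weight_c = n_samples / (n_classes * count_c)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    # One weight per class, ordered by class id; rarer classes get larger weights.
    return [n_samples / (n_classes * counts[c]) for c in sorted(counts)]


# Toy data (assumed): 2 examples of class 0 and 6 examples of class 1.
example_labels = [0] * 2 + [1] * 6
weights = balanced_class_weights(example_labels)
print(weights)
```

The resulting list could then be passed to torch.tensor(...) and used as the weight argument of nn.CrossEntropyLoss.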

# instantiate the trainer class and check for available devices
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=test_data
)

The code above is from the first link (the tutorial).

The code below is from the other two links.

import torch
from torch import nn
from transformers import Trainer


class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss (suppose one has 3 labels with different weights)
        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

I modified that code as shown below, because I am going to train on a GPU.

device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = 'cpu'
device

Returns 'cuda'

#class weights

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss (two labels with different weights)
        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 0.5243])).to(device)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1)).to(device)
        return (loss, outputs) if return_outputs else loss

My question: I added .to(device) to the two lines below. Is that the correct way to modify the code so that everything runs on the GPU without problems?

loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 0.5243])).to(device)
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1)).to(device)
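As far as I understand, nn.CrossEntropyLoss keeps its weight as a module buffer, so calling .to(device) on the loss module should move the weights. The loss it returns already lives on the same device as the logits, which would make the second .to(device) redundant. An alternative that avoids a global device variable is to create the weight tensor directly on the logits' device. Here is a minimal sketch of that idea, assuming two classes. The helper name weighted_cross_entropy and the toy tensors are hypothetical and not part of transformers.

```python
import torch
from torch import nn


def weighted_cross_entropy(logits, labels, class_weights):
    """Hypothetical helper: weighted loss built on the logits' own device."""
    # Create the weights with the same dtype and device as the logits,
    # so no separate .to(device) call is needed.
    weight = torch.as_tensor(class_weights, dtype=logits.dtype, device=logits.device)
    loss_fct = nn.CrossEntropyLoss(weight=weight)
    num_labels = logits.size(-1)
    return loss_fct(logits.view(-1, num_labels), labels.view(-1))


# Toy inputs (assumed shapes: batch of 2 examples, 2 classes).
device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.tensor([[2.0, 0.5], [0.1, 1.5]], device=device)
labels = torch.tensor([0, 1], device=device)

loss = weighted_cross_entropy(logits, labels, [1.0, 0.5243])
print(loss.device, loss.item())
```

Inside compute_loss, the same pattern would amount to passing weight=torch.tensor([1.0, 0.5243], device=logits.device) to nn.CrossEntropyLoss.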