Using Trainer with custom model and custom dataset

I’m trying to train a custom model using the Trainer class, but I’m receiving the following error:

TypeError: forward() got an unexpected keyword argument 'labels'

Below is a minimal example of my implementation:

import torch
import torch.nn as nn
from datasets import Dataset
from transformers import Trainer, TrainingArguments


class MLP(nn.Module):
    """Two-layer perceptron that also computes its own cross-entropy loss.

    The network is ``Linear(d_in, d_in) -> ReLU -> Linear(d_in, d_out)``.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Layers are declared in the order ``forward`` applies them.
        self.fc1 = nn.Linear(d_in, d_in)
        self.activation = nn.ReLU()
        self.fc2 = nn.Linear(d_in, d_out)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x, y):
        """Run the network on ``x`` and score the result against ``y``.

        Returns:
            A dict with ``'loss'`` (``self.criterion`` applied to the logits
            and ``y``) and ``'logits'`` (the output of the second linear
            layer).
        """
        hidden = self.activation(self.fc1(x))
        logits = self.fc2(hidden)
        return {'loss': self.criterion(logits, y), 'logits': logits}


if __name__ == "__main__":
    # Toy problem: `bs` random samples of `n_features` floats, each assigned
    # one of `n_classes` class indices.
    n_classes = 3
    bs = 10
    n_features = 20

    training_args = TrainingArguments(
        output_dir='./checkpoint',
        num_train_epochs=3,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        report_to='none',
        save_strategy='no',
        # Do not let the Trainer drop dataset columns before batching.
        remove_unused_columns=False,
    )

    model = MLP(n_features, n_classes)
    x = torch.randn(bs, n_features)
    # CrossEntropyLoss needs one class index per sample, in [0, n_classes).
    # The previous `torch.arange(0, bs)[..., None]` produced indices up to
    # bs - 1 (out of range for n_classes outputs) with an extra trailing
    # dimension, so cycle through the valid classes instead.
    y = torch.arange(bs) % n_classes

    # The same data is reused as the validation set in this minimal example.
    val_dataset = train_dataset = Dataset.from_dict({'input_ids': x, 'labels': y})

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )

    trainer.train()

I’m not sure what I’m doing wrong. I believe that I’m meeting all of the requirements described here.

Hi,

It looks like the Trainer passes each batch to your model as keyword arguments named after the dataset columns, i.e. model(input_ids=..., labels=...). However, your forward method’s parameters are named x and y, so it doesn’t accept a labels (or input_ids) keyword argument.

You can fix it by updating your forward method:

def forward(self, input_ids, labels=None):
    """Compute logits for ``input_ids`` and, when given, the loss for ``labels``.

    The parameter names match the dataset columns (``input_ids`` and
    ``labels``), because the Trainer passes them to the model as keyword
    arguments.

    Args:
        input_ids: input features fed to ``self.fc1``.
        labels: optional targets passed to ``self.criterion``; leave as
            ``None`` to run the model without computing a loss.

    Returns:
        A dict with ``'logits'`` and, only when ``labels`` is provided,
        ``'loss'`` (listed first).
    """
    hidden = self.activation(self.fc1(input_ids))
    logits = self.fc2(hidden)
    if labels is None:
        # No targets (e.g. prediction): there is nothing to score against.
        return {'logits': logits}
    return {'loss': self.criterion(logits, labels), 'logits': logits}