Specify Loss for Trainer / TrainingArguments

You can override the compute_loss method of the Trainer, like so:

from torch import nn
from transformers import Trainer

class RegressionTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # mean-squared error between the model's single logit and the float label
        loss_fct = nn.MSELoss()
        loss = loss_fct(logits.squeeze(), labels.squeeze())
        return (loss, outputs) if return_outputs else loss
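
A quick usage sketch, assuming you already have a model with a single output logit, a dataset with a float "labels" column, and TrainingArguments (the names below are placeholders, not part of the original snippet):

from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="regression-out")
trainer = RegressionTrainer(
    model=model,                   # placeholder: your regression model
    args=training_args,
    train_dataset=train_dataset,   # placeholder: dataset with float "labels"
)
trainer.train()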

However, several models in the library have a config attribute called problem_type, which you can set to "regression". In that case you don't need to override anything: the model computes the appropriate regression loss (MSE) itself, and you can use the default Trainer.
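
A minimal sketch of the problem_type route, assuming a sequence-classification model; the checkpoint name bert-base-uncased is just an example:

from transformers import AutoConfig, AutoModelForSequenceClassification

config = AutoConfig.from_pretrained(
    "bert-base-uncased",
    problem_type="regression",  # tells the model head to use MSELoss
    num_labels=1,               # single continuous target per example
)
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", config=config
)
# With problem_type="regression", the model applies MSELoss internally when
# labels are passed, so the plain Trainer works without a custom compute_loss.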
