Implementing a Trainer with custom loss produces key error

Here’s my code:

import pandas as pd
from datasets import Dataset
from transformers import AutoTokenizer, PreTrainedTokenizer
from transformers import AutoModel, PreTrainedModel, AutoConfig, EarlyStoppingCallback
import torch.nn as nn
import torch
from transformers import TrainingArguments, Trainer
from functools import partial

class CustomModelForRegression(nn.Module):
    def __init__(self, hf_model_name):
        self.base_model = AutoModel.from_pretrained(hf_model_name)
        config = AutoConfig.from_pretrained(hf_model_name)
        self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids, attention_mask=None, head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        outputs = self.base_model(input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
        pooled_output = outputs[0][:, 0]  
        pooled_output = self.pre_classifier(pooled_output)  
        pooled_output = nn.ReLU()(pooled_output)  
        pooled_output = self.dropout(pooled_output)  
        logits = self.classifier(pooled_output)

        return logits

class CustomTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fn = nn.MSELoss()

    def compute_loss(self, model, inputs, return_outputs=False):
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = self.label_smoother(outputs, labels)
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
            loss = loss*loss
            loss = loss.mean()

        return (loss, outputs) if return_outputs else loss

def tokenize_function(tokenizer, examples):
    return tokenizer(examples["response"], truncation=True, padding="max_length")

def load_dataset():
    df = pd.read_csv("data/emp_df.csv", encoding="latin1", lineterminator='\n')
    dataset = Dataset.from_pandas(df)
    dataset = x: {"labels": [x["emp_mean"], x["emp_std"]]}, remove_columns=["emp_mean", "emp_std"])
    return dataset

def load_tokenizer_model(hf_model_name = "distilbert-base-uncased"):
    tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
    model = CustomModelForRegression(hf_model_name)
    return tokenizer, model

def compute_metrics(pred_labels):
    pred, labels = pred_labels
    loss_fn = nn.MSELoss()
    loss = loss_fn(pred.view(-1), labels.view(-1))
    return {"mse": loss.item()}

def train():
    tokenizer, model = load_tokenizer_model()
    dataset = load_dataset()
    tokenized_dataset =, tokenizer), batched=True)
    data_dict = tokenized_dataset.train_test_split(test_size=0.2)
    data_dict['eval'] = data_dict.pop('test')

    training_args = TrainingArguments(
    trainer = CustomTrainer(
                early_stopping_patience=50, early_stopping_threshold=0.01


if __name__ == "__main__":

Here’s the error:

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
{'loss': 0.0593, 'learning_rate': 4.999994871794872e-05, 'epoch': 0.03}                                             
{'loss': 0.001, 'learning_rate': 4.999989743589743e-05, 'epoch': 0.05}
{'loss': 0.0013, 'learning_rate': 4.9999846153846155e-05, 'epoch': 0.08}
{'loss': 0.041, 'learning_rate': 4.999979487179488e-05, 'epoch': 0.1}
{'loss': 0.0491, 'learning_rate': 4.999974358974359e-05, 'epoch': 0.13}
{'loss': 0.0, 'learning_rate': 4.999969230769231e-05, 'epoch': 0.15}
{'loss': 0.0161, 'learning_rate': 4.999964102564103e-05, 'epoch': 0.18}
{'loss': 0.024, 'learning_rate': 4.9999589743589745e-05, 'epoch': 0.21}
{'loss': 0.0244, 'learning_rate': 4.9999538461538467e-05, 'epoch': 0.23}
{'loss': 0.0095, 'learning_rate': 4.999948717948718e-05, 'epoch': 0.26}
{'loss': 0.0081, 'learning_rate': 4.99994358974359e-05, 'epoch': 0.28}
{'loss': 0.0222, 'learning_rate': 4.999938461538462e-05, 'epoch': 0.31}
{'loss': 0.0221, 'learning_rate': 4.9999333333333334e-05, 'epoch': 0.33}
{'loss': 0.0033, 'learning_rate': 4.9999282051282056e-05, 'epoch': 0.36}
{'loss': 0.0078, 'learning_rate': 4.999923076923077e-05, 'epoch': 0.38}
{'loss': 0.0043, 'learning_rate': 4.999917948717949e-05, 'epoch': 0.41}
{'loss': 0.0117, 'learning_rate': 4.999912820512821e-05, 'epoch': 0.44}
{'loss': 0.0295, 'learning_rate': 4.9999076923076924e-05, 'epoch': 0.46}
{'loss': 0.0013, 'learning_rate': 4.9999025641025646e-05, 'epoch': 0.49}
{'loss': 0.039, 'learning_rate': 4.999897435897436e-05, 'epoch': 0.51}
{'eval_runtime': 0.446, 'eval_samples_per_second': 174.884, 'eval_steps_per_second': 22.421, 'epoch': 0.51}
early stopping required metric_for_best_model, but did not find eval_mse so early stopping is disabled
{'loss': 0.0057, 'learning_rate': 4.999892307692308e-05, 'epoch': 0.54}
{'loss': 0.0162, 'learning_rate': 4.99988717948718e-05, 'epoch': 0.56}
{'loss': 0.0008, 'learning_rate': 4.9998820512820514e-05, 'epoch': 0.59}
{'loss': 0.0042, 'learning_rate': 4.9998769230769236e-05, 'epoch': 0.62}
{'loss': 0.0032, 'learning_rate': 4.999871794871795e-05, 'epoch': 0.64}
{'loss': 0.0063, 'learning_rate': 4.999866666666667e-05, 'epoch': 0.67}
{'loss': 0.0121, 'learning_rate': 4.999861538461539e-05, 'epoch': 0.69}
{'loss': 0.0096, 'learning_rate': 4.99985641025641e-05, 'epoch': 0.72}
{'loss': 0.0068, 'learning_rate': 4.999851282051282e-05, 'epoch': 0.74}
{'loss': 0.0102, 'learning_rate': 4.999846153846154e-05, 'epoch': 0.77}
{'loss': 0.0053, 'learning_rate': 4.999841025641026e-05, 'epoch': 0.79}
{'loss': 0.0019, 'learning_rate': 4.999835897435898e-05, 'epoch': 0.82}
{'loss': 0.0064, 'learning_rate': 4.999830769230769e-05, 'epoch': 0.85}
{'loss': 0.0052, 'learning_rate': 4.999825641025641e-05, 'epoch': 0.87}
{'loss': 0.0079, 'learning_rate': 4.999820512820513e-05, 'epoch': 0.9}
{'loss': 0.0053, 'learning_rate': 4.999815384615385e-05, 'epoch': 0.92}
{'loss': 0.0004, 'learning_rate': 4.999810256410257e-05, 'epoch': 0.95}
{'loss': 0.003, 'learning_rate': 4.999805128205128e-05, 'epoch': 0.97}
{'loss': 0.0031, 'learning_rate': 4.9998e-05, 'epoch': 1.0}
{'loss': 0.0005, 'learning_rate': 4.999794871794872e-05, 'epoch': 1.03}
{'eval_runtime': 0.4509, 'eval_samples_per_second': 172.994, 'eval_steps_per_second': 22.179, 'epoch': 1.03}
early stopping required metric_for_best_model, but did not find eval_mse so early stopping is disabled
{'loss': 0.0021, 'learning_rate': 4.999789743589744e-05, 'epoch': 1.05}
{'loss': 0.0043, 'learning_rate': 4.999784615384616e-05, 'epoch': 1.08}
{'loss': 0.0103, 'learning_rate': 4.999779487179487e-05, 'epoch': 1.1}
{'loss': 0.009, 'learning_rate': 4.999774358974359e-05, 'epoch': 1.13}
{'loss': 0.0001, 'learning_rate': 4.999769230769231e-05, 'epoch': 1.15}
{'loss': 0.0249, 'learning_rate': 4.999764102564103e-05, 'epoch': 1.18}
{'loss': 0.0018, 'learning_rate': 4.999758974358975e-05, 'epoch': 1.21}
{'loss': 0.0028, 'learning_rate': 4.999753846153846e-05, 'epoch': 1.23}
{'loss': 0.0015, 'learning_rate': 4.9997487179487184e-05, 'epoch': 1.26}
{'loss': 0.0013, 'learning_rate': 4.99974358974359e-05, 'epoch': 1.28}
{'loss': 0.0032, 'learning_rate': 4.999738461538462e-05, 'epoch': 1.31}
{'loss': 0.0028, 'learning_rate': 4.999733333333334e-05, 'epoch': 1.33}
{'loss': 0.0007, 'learning_rate': 4.999728205128205e-05, 'epoch': 1.36}
{'loss': 0.0011, 'learning_rate': 4.9997230769230774e-05, 'epoch': 1.38}
{'loss': 0.0028, 'learning_rate': 4.999717948717949e-05, 'epoch': 1.41}
{'loss': 0.0003, 'learning_rate': 4.9997128205128204e-05, 'epoch': 1.44}
{'loss': 0.0047, 'learning_rate': 4.9997076923076926e-05, 'epoch': 1.46}
{'loss': 0.0046, 'learning_rate': 4.999702564102564e-05, 'epoch': 1.49}
{'loss': 0.0051, 'learning_rate': 4.9996974358974364e-05, 'epoch': 1.51}
{'loss': 0.0016, 'learning_rate': 4.999692307692308e-05, 'epoch': 1.54}
{'eval_runtime': 0.4461, 'eval_samples_per_second': 174.841, 'eval_steps_per_second': 22.416, 'epoch': 1.54}
early stopping required metric_for_best_model, but did not find eval_mse so early stopping is disabled
{'loss': 0.0049, 'learning_rate': 4.9996871794871794e-05, 'epoch': 1.56}
{'loss': 0.0012, 'learning_rate': 4.9996820512820516e-05, 'epoch': 1.59}
{'loss': 0.0025, 'learning_rate': 4.999676923076924e-05, 'epoch': 1.62}
{'loss': 0.0023, 'learning_rate': 4.999671794871795e-05, 'epoch': 1.64}
{'loss': 0.004, 'learning_rate': 4.999666666666667e-05, 'epoch': 1.67}
{'loss': 0.0033, 'learning_rate': 4.9996615384615384e-05, 'epoch': 1.69}
{'loss': 0.0032, 'learning_rate': 4.9996564102564106e-05, 'epoch': 1.72}
{'loss': 0.0025, 'learning_rate': 4.999651282051283e-05, 'epoch': 1.74}
{'loss': 0.0004, 'learning_rate': 4.999646153846154e-05, 'epoch': 1.77}
{'loss': 0.0034, 'learning_rate': 4.999641025641026e-05, 'epoch': 1.79}
{'loss': 0.0037, 'learning_rate': 4.9996358974358973e-05, 'epoch': 1.82}
{'loss': 0.0036, 'learning_rate': 4.9996307692307695e-05, 'epoch': 1.85}
{'loss': 0.0036, 'learning_rate': 4.999625641025642e-05, 'epoch': 1.87}
{'loss': 0.0045, 'learning_rate': 4.999620512820513e-05, 'epoch': 1.9}
{'loss': 0.002, 'learning_rate': 4.999615384615385e-05, 'epoch': 1.92}
{'loss': 0.0069, 'learning_rate': 4.999610256410256e-05, 'epoch': 1.95}
{'loss': 0.0088, 'learning_rate': 4.9996051282051285e-05, 'epoch': 1.97}
{'loss': 0.001, 'learning_rate': 4.999600000000001e-05, 'epoch': 2.0}
{'loss': 0.0025, 'learning_rate': 4.999594871794872e-05, 'epoch': 2.03}
{'loss': 0.0048, 'learning_rate': 4.999589743589744e-05, 'epoch': 2.05}
{'eval_runtime': 0.4465, 'eval_samples_per_second': 174.692, 'eval_steps_per_second': 22.396, 'epoch': 2.05}
early stopping required metric_for_best_model, but did not find eval_mse so early stopping is disabled
{'loss': 0.0052, 'learning_rate': 4.999584615384615e-05, 'epoch': 2.08}
{'loss': 0.0102, 'learning_rate': 4.9995794871794875e-05, 'epoch': 2.1}
{'loss': 0.0004, 'learning_rate': 4.999574358974359e-05, 'epoch': 2.13}
{'loss': 0.0035, 'learning_rate': 4.999569230769231e-05, 'epoch': 2.15}
{'loss': 0.0034, 'learning_rate': 4.999564102564103e-05, 'epoch': 2.18}
{'loss': 0.0051, 'learning_rate': 4.999558974358974e-05, 'epoch': 2.21}
{'loss': 0.0007, 'learning_rate': 4.9995538461538465e-05, 'epoch': 2.23}
{'loss': 0.0043, 'learning_rate': 4.999548717948718e-05, 'epoch': 2.26}
{'loss': 0.0001, 'learning_rate': 4.99954358974359e-05, 'epoch': 2.28}
{'loss': 0.003, 'learning_rate': 4.999538461538462e-05, 'epoch': 2.31}
{'loss': 0.0036, 'learning_rate': 4.999533333333334e-05, 'epoch': 2.33}
{'loss': 0.0038, 'learning_rate': 4.9995282051282054e-05, 'epoch': 2.36}
{'loss': 0.0021, 'learning_rate': 4.999523076923077e-05, 'epoch': 2.38}
{'loss': 0.0014, 'learning_rate': 4.999517948717949e-05, 'epoch': 2.41}
{'loss': 0.0025, 'learning_rate': 4.999512820512821e-05, 'epoch': 2.44}
{'loss': 0.0052, 'learning_rate': 4.999507692307693e-05, 'epoch': 2.46}
{'loss': 0.0006, 'learning_rate': 4.9995025641025644e-05, 'epoch': 2.49}
{'loss': 0.0072, 'learning_rate': 4.999497435897436e-05, 'epoch': 2.51}
{'loss': 0.0012, 'learning_rate': 4.999492307692308e-05, 'epoch': 2.54}
{'loss': 0.0044, 'learning_rate': 4.9994871794871796e-05, 'epoch': 2.56}
{'eval_runtime': 0.4452, 'eval_samples_per_second': 175.207, 'eval_steps_per_second': 22.462, 'epoch': 2.56}
early stopping required metric_for_best_model, but did not find eval_mse so early stopping is disabled
Traceback (most recent call last):
  File "/workspaces/bot-empathy/", line 122, in <module>
  File "/workspaces/bot-empathy/", line 119, in train
  File "/usr/local/python/3.10.4/lib/python3.10/site-packages/transformers/", line 1662, in train
    return inner_training_loop(
  File "/usr/local/python/3.10.4/lib/python3.10/site-packages/transformers/", line 2006, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "/usr/local/python/3.10.4/lib/python3.10/site-packages/transformers/", line 2291, in _maybe_log_save_evaluate
    self._save_checkpoint(model, trial, metrics=metrics)
  File "/usr/local/python/3.10.4/lib/python3.10/site-packages/transformers/", line 2394, in _save_checkpoint
    metric_value = metrics[metric_to_check]
KeyError: 'eval_mse'

Why is the custom mse metric not available?

The Trainer was primarily built for the models in Transformers, and as such makes a certain number of assumptions (that you can find in the docs, scroll to the box in red). Here your model does not return a loss when labels are provided, which is fine for training since you overrode compute_loss (though I don’t see where you compute the loss there) but this is not the only place this behavior is used.

You model signature doesn’t contain any argument named labels, so the Trainer thinks your model does not accept any labels (righly so) and is thus not capable of doing evaluation, and thus does not even try to call your compute_metrics function. So the best way to fix this is just to have your model accept labels (with default None) and return the loss when labels are provided, like all the models from transformers, and it should make the Trainer happy :slight_smile:

Thanks so much, Sylvain! This makes sense. I changed the code for the model to accept labels and return the loss with logits in forward. I am computing the MSE loss in compute_loss but the forward method was not returning any loss. I fixed that and I was not being mindful of the types (sigh, Python) in a few places. Fixing that and adding your suggestion fixed the training loop and metric calculations. Bad case of not RTFM :smile:

Super grateful for your answers, as always, @sgugger!