Metrics for Training Set in Trainer

I will expand on @sid8491’s answer. In my use case, I have to keep finetuning on multiple datasets from different languages. This way, we can keep track of metrics (loss, precision, recall, …) across different language datasets.

If anyone has suggestions for cleaner code, please share them.

from copy import deepcopy

import numpy as np
import torch
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, log_loss
from torch.nn import CrossEntropyLoss
from transformers import TrainerCallback

lang = 'en'

class CustomCallback(TrainerCallback):
    """Trainer callback that also evaluates on the *training* set at the
    end of each epoch, logging its metrics under the 'train@<lang>' prefix.
    """

    def __init__(self, trainer) -> None:
        super().__init__()
        # Keep a handle on the Trainer so we can trigger extra evaluations.
        self._trainer = trainer

    def on_epoch_end(self, args, state, control, **kwargs):
        if not control.should_evaluate:
            return
        # Snapshot `control` before the nested evaluate() call mutates it;
        # returning the snapshot hands the pre-evaluation state back to
        # the Trainer so the regular eval-set evaluation still runs.
        snapshot = deepcopy(control)
        self._trainer.evaluate(
            eval_dataset=self._trainer.train_dataset,
            metric_key_prefix="train@" + lang,
        )
        return snapshot

def compute_metrics(pred):
    """Compute classification metrics for a Trainer ``EvalPrediction``.

    Args:
        pred: object with ``label_ids`` (gold labels) and ``predictions``
            (logits, presumably shaped (n_samples, n_classes) — confirm
            against the model head).

    Returns:
        dict mapping ``'<metric>@<lang>'`` to a plain Python number, so the
        Trainer can serialize it in its logs.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='macro')
    acc = accuracy_score(labels, preds)
    # Derive the class count from the logits themselves instead of relying
    # on a module-level `num_labels` global that is never defined (bug in
    # the original); this also generalizes across datasets.
    num_labels = pred.predictions.shape[-1]
    loss_fct = CrossEntropyLoss()
    logits = torch.tensor(pred.predictions)
    # Use a fresh name instead of rebinding `labels` (the original shadowed
    # the numpy labels with a tensor mid-function).
    labels_t = torch.tensor(labels)
    loss = loss_fct(logits.view(-1, num_labels), labels_t.view(-1))
    return {
        'accuracy@' + lang: acc,
        'f1@' + lang: f1,
        'precision@' + lang: precision,
        'recall@' + lang: recall,
        # .item() unwraps the 0-d tensor to a float so the metrics dict
        # stays JSON-serializable for the Trainer's logging.
        'loss@' + lang: loss.item(),
    }

# Training configuration: evaluate and log once per epoch so the callback's
# train-set evaluation fires at every epoch boundary.
training_args = TrainingArguments(
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    learning_rate=1e-4,
    num_train_epochs=2,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    output_dir=MODEL_DIR+'_EN',  # MODEL_DIR presumably defined earlier — not visible here
    overwrite_output_dir=True,
    remove_unused_columns=False,
    save_total_limit=1,
)

# Build the Trainer on the English train/validation splits; `model`,
# `en_train_dataset`, and `en_valid_dataset` come from elsewhere in the file.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=en_train_dataset,
    eval_dataset=en_valid_dataset,
    compute_metrics=compute_metrics
)

# Register the callback that re-evaluates on the training set each epoch.
trainer.add_callback(CustomCallback(trainer)) 

train_result = trainer.train()

# Final held-out evaluation; metrics are logged under the 'test_en' prefix.
trainer.evaluate(metric_key_prefix='test_en',
                eval_dataset=en_test_dataset)
2 Likes