How to set up my custom compute_metrics correctly for Triplet Loss?

I’m trying to set up Triplet Loss for my pipeline, but it seems that my compute_metrics just won’t register.

Here is my pipeline setup:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./triplet_results",
    eval_strategy="steps",
    eval_steps=100,
    logging_steps=100,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=1e-5,
    save_strategy="steps",
    save_steps=100,
    load_best_model_at_end=True,
    remove_unused_columns=False,
    metric_for_best_model="accuracy",
    greater_is_better=True,
)


trainer = TripletTrainer(
    model=model,
    args=training_args,
    train_dataset=train_triplet_dataset,
    eval_dataset=validation_triplet_dataset,
    compute_metrics=compute_triplet_metrics,
    data_collator=triplet_data_collator
)

And here is my implementation of the contrastive learning part: the metrics function and the custom trainer.


import torch
import torch.nn.functional as F


def compute_triplet_metrics(eval_pred):
    # eval_pred.predictions should be the tuple returned by prediction_step
    anchor_embeddings, positive_embeddings, negative_embeddings = eval_pred.predictions

    anchor_emb = torch.from_numpy(anchor_embeddings)
    positive_emb = torch.from_numpy(positive_embeddings)
    negative_emb = torch.from_numpy(negative_embeddings)
    d_pos = F.pairwise_distance(anchor_emb, positive_emb)
    d_neg = F.pairwise_distance(anchor_emb, negative_emb)
    # A triplet counts as correct when the anchor sits closer to the positive
    accuracy = torch.mean((d_pos < d_neg).float()).item()

    return {"accuracy": accuracy}


import torch.nn as nn
from types import SimpleNamespace

from transformers import Trainer


class TripletTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        anchor_inputs = inputs.get("anchor_input_values")
        positive_inputs = inputs.get("positive_input_values")
        negative_inputs = inputs.get("negative_input_values")

        anchor_outputs = model(anchor_inputs).last_hidden_state
        positive_outputs = model(positive_inputs).last_hidden_state
        negative_outputs = model(negative_inputs).last_hidden_state

        anchor_embedding = torch.mean(anchor_outputs, dim=1)
        positive_embedding = torch.mean(positive_outputs, dim=1)
        negative_embedding = torch.mean(negative_outputs, dim=1)

        loss_fct = nn.TripletMarginLoss(margin=1.0)
        loss = loss_fct(anchor_embedding, positive_embedding, negative_embedding)

        # Lightweight container so prediction_step can read the embeddings back
        outputs = SimpleNamespace(
            loss=loss,
            embeddings=(anchor_embedding, positive_embedding, negative_embedding),
        )
        return (loss, outputs) if return_outputs else loss

    def prediction_step(
        self,
        model,
        inputs,
        prediction_loss_only,
        ignore_keys=None,
    ):
        with torch.no_grad():
            loss, outputs = self.compute_loss(
                model, inputs, return_outputs=True)

        # Forward the three embedding tensors as the predictions tuple
        predictions = outputs.embeddings

        labels = None

        return (loss.detach(), predictions, labels)

The problem is that the metric does not register at all, so I don’t think it’s an issue with, say, the eval_ prefix. The output is:

KeyError: "The metric_for_best_model training argument is set to 'eval_accuracy', which is not found in the evaluation metrics. The available evaluation metrics are: ['eval_loss']. Consider changing the metric_for_best_model via the TrainingArguments."
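
As far as I can tell, the eval_ in the error just comes from the Trainer itself: when picking the best checkpoint it prepends eval_ to metric_for_best_model before looking the key up, roughly like this (paraphrased, not the exact transformers source):

metric_to_check = args.metric_for_best_model        # "accuracy"
if not metric_to_check.startswith("eval_"):
    metric_to_check = f"eval_{metric_to_check}"     # -> "eval_accuracy"
metric_value = metrics[metric_to_check]             # KeyError: only "eval_loss" exists

So the real question is why compute_triplet_metrics never runs and eval_accuracy never gets produced.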


It seems that the predictions returned from prediction_step need to be NumPy arrays rather than torch tensors.

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        with torch.no_grad():
            loss, out = self.compute_loss(model, inputs, return_outputs=True)
        a, p, n = out.embeddings
        preds = tuple(t.detach().cpu().numpy() for t in (a, p, n))  # numpy for metrics
        return (loss.detach(), preds, None)  # label_ids can be None
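
With that change, the tuple should flow into compute_triplet_metrics and the key should show up under the eval_ prefix. A quick way to verify (a minimal sketch, assuming the trainer above is already built):

metrics = trainer.evaluate()
print(metrics.keys())  # should now include 'eval_accuracy' alongside 'eval_loss'

metric_for_best_model="accuracy" then resolves to "eval_accuracy" and load_best_model_at_end works as intended.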