Implementing triplet loss with ViT-MSN

I’m currently working on training a ViT-MSN (Masked Siamese Network) model using triplet loss for a self-supervised or metric learning task. I’m using Hugging Face’s Trainer class for training, and I would like to integrate a miner (like those from pytorch-metric-learning) to improve the selection of positive and negative samples during training.

Here are the key challenges I’m facing:

  1. How to integrate triplet loss into a Hugging Face Trainer setup?
    I understand that I might need to subclass the Trainer or override the compute_loss method, but I’m not entirely sure how to best pass the anchor, positive, and negative embeddings properly when working with a ViT-MSN model.

  2. How to use a miner within the Hugging Face Trainer class?
    I’d like to use a miner like BatchHardMiner or TripletMarginMiner from pytorch-metric-learning. What’s the best practice for integrating such miners in the Hugging Face training loop, especially when customizing loss computation?

If anyone has done something similar or has suggestions/examples on how to:

- structure the input data,
- modify the Trainer to work with triplet loss and miners,
- and properly align it with the ViT-MSN architecture,

your input would be greatly appreciated!

Thanks in advance!
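Edit: for context, here's roughly how I'm structuring the input data at the moment (a sketch; `LabeledImageDataset` and the key names are my own convention):

```python
# Sketch of the input side: a plain label-per-image dataset, so a miner can
# form triplets within each batch. `processor` is any callable such as
# AutoImageProcessor.from_pretrained(<your ViT-MSN checkpoint>).
import torch


class LabeledImageDataset(torch.utils.data.Dataset):  # hypothetical helper
    def __init__(self, images, labels, processor):
        assert len(images) == len(labels)
        self.images, self.labels, self.processor = images, labels, processor

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        pixel_values = self.processor(
            self.images[idx], return_tensors="pt"
        )["pixel_values"][0]
        return {"pixel_values": pixel_values,
                "labels": torch.tensor(self.labels[idx])}
```

My understanding is that for mining to work, each batch needs at least two samples per class, so I'm planning to pair this with something like pytorch-metric-learning's MPerClassSampler rather than plain shuffling.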


I think overriding Trainer.compute_loss is the most reliable way.

https://stackoverflow.com/questions/66302371/how-to-specify-the-loss-function-when-finetuning-a-model-using-the-huggingface-t
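If you'd rather feed explicit triplets instead of mining them per batch, the same override works with batches keyed as anchor/positive/negative — a sketch, where those key names are just a convention your collator would need to follow:

```python
# Sketch: compute_loss for batches of pre-built triplets, using torch's
# built-in triplet margin loss. The "anchor"/"positive"/"negative" keys are
# my own convention, not anything Trainer requires.
import torch.nn.functional as F
from transformers import Trainer


class ExplicitTripletTrainer(Trainer):  # hypothetical name
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # One forward pass per branch of the triplet, CLS-token pooling each time
        anc = model(pixel_values=inputs["anchor"]).last_hidden_state[:, 0]
        pos = model(pixel_values=inputs["positive"]).last_hidden_state[:, 0]
        neg = model(pixel_values=inputs["negative"]).last_hidden_state[:, 0]
        loss = F.triplet_margin_loss(anc, pos, neg, margin=0.2)
        if return_outputs:
            return loss, {"anchor_embeddings": anc}
        return loss
```

The trade-off: explicit triplets give you full control over pair selection, but the miner approach usually scales better since it reuses every embedding in the batch.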