Implement Multiple Negatives Ranking Loss

Hi!

I was wondering if there are any resources or guides on how to get started with Multiple Negatives Ranking Loss (MNRL) in transformers. I have had a good look at the Trainer class, and subclassing it to plug in a custom loss seems easy enough.

However, I cannot find anything on triplet loss or MNRL to use as a starting point. An example would be great to see!

My dataset only contains positive pairs (doc, query). This is very similar to what the Sentence Transformers library offers, but that library has drawbacks: the loss cannot be tracked easily, and logging to third-party systems like W&B is not implemented.

from transformers import Trainer

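class CustomTrainer(Trainer):
    # **kwargs absorbs extra arguments (such as num_items_in_batch) that newer Trainer versions pass
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)

        # multiple_negatives_ranking_loss is a placeholder; this is the part I am missing
        loss = multiple_negatives_ranking_loss(outputs.last_hidden_state)

        if return_outputs:
            return loss, outputs
        return loss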
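
For reference, here is a minimal sketch of what the loss itself could look like in plain PyTorch. It is not taken from transformers or from Sentence Transformers. It assumes that the queries and docs have already been encoded into two separate embedding tensors of shape (batch, dim), where row i of each tensor belongs to the same positive pair, so every other doc in the batch acts as a negative for query i. The helper name mean_pool, the cosine similarity, and scale=20.0 are assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F


def mean_pool(last_hidden_state, attention_mask):
    """Average the token embeddings of each sequence, ignoring padding."""
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid division by zero
    return summed / counts


def multiple_negatives_ranking_loss(query_emb, doc_emb, scale=20.0):
    """In-batch negatives loss: query i should score doc i above all other docs.

    query_emb, doc_emb: tensors of shape (batch, dim); row i of each is a positive pair.
    scale: multiplier applied to the cosine similarities (assumed value).
    """
    query_emb = F.normalize(query_emb, p=2, dim=-1)
    doc_emb = F.normalize(doc_emb, p=2, dim=-1)
    # scores[i, j] is the scaled cosine similarity between query i and doc j
    scores = (query_emb @ doc_emb.T) * scale
    # The matching doc for query i sits on the diagonal, so the target class is i
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)
```

Inside compute_loss, one option is to run the model twice, once on the query tokens and once on the doc tokens, pool each output with mean_pool, and pass the two embedding tensors to the loss. That needs a data collator that keeps the two sides separate (for example under hypothetical keys such as query_input_ids and doc_input_ids), because the single outputs.last_hidden_state in the snippet above cannot tell queries from docs. The Trainer logs the value returned by compute_loss, so the loss should then reach W&B through the usual report_to setting.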