Triplet (contrastive) loss for sequence embedding

I would like to use transformers to train a MIDI2Vec sequence encoder, so that I can map new MIDI sequences into a high-dimensional embedding space.

To do so, I was thinking of using a transformer encoder (with a BERT-like structure) and then applying a contrastive triplet loss on top of it.

That is, for each datum in a batch I have two inputs containing the same underlying data but with different augmentations.
I then need to encode all the inputs together and apply a contrastive loss between each datum (the anchor), its augmented counterpart (the positive), and a different example from the batch (the negative).
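To make the loss part concrete, here is a minimal sketch of what I have in mind, assuming the encoder already produces one embedding per sequence; all shapes and names are placeholders, and the negatives are simply the positives shifted by one within the batch:

```python
import torch
import torch.nn.functional as F

# Placeholder batch of embeddings; in practice these would come from the encoder.
B, D = 32, 256
anchor = F.normalize(torch.randn(B, D), dim=-1)    # embeddings of the original sequences
positive = F.normalize(torch.randn(B, D), dim=-1)  # embeddings of their augmented versions
negative = positive.roll(shifts=1, dims=0)         # "different example": pair each row with another item's positive

loss = F.triplet_margin_loss(anchor, positive, negative, margin=0.2)
print(loss.item())
```

The in-batch shift is just the simplest way to get negatives; I could presumably swap in harder negative mining later without changing the rest of the setup.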

How should I go about doing this? What already-written model components should I use?
I was thinking:

  1. I need a custom tokenizer that is deterministic over MIDI and covers pitch, velocity, time, etc.
  2. An embedding class that can embed all these features and concatenate them
  3. Use some stacked transformer encoder architecture
  4. Somehow, apply the loss (a rough sketch of how these pieces might fit together is below).
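For points 2 and 3, this is roughly what I imagine, as a sketch only: per-event features are embedded separately and concatenated, then fed through stacked transformer encoder layers and mean-pooled into a single sequence embedding. The vocabulary sizes, the bucketing of velocity/time, and all dimensions are assumptions of mine:

```python
import torch
import torch.nn as nn

class MidiEventEmbedding(nn.Module):
    """Embeds per-event features (pitch, velocity bin, time bin) and concatenates them."""
    def __init__(self, d_pitch=64, d_vel=32, d_time=32):
        super().__init__()
        self.pitch = nn.Embedding(128, d_pitch)    # MIDI pitches 0-127
        self.velocity = nn.Embedding(128, d_vel)   # assumes velocity kept as 0-127
        self.time = nn.Embedding(256, d_time)      # assumes time shifts quantized to 256 bins
        self.d_model = d_pitch + d_vel + d_time

    def forward(self, pitch, velocity, time):      # each: (batch, seq_len) integer tensors
        return torch.cat([self.pitch(pitch), self.velocity(velocity), self.time(time)], dim=-1)

class MidiEncoder(nn.Module):
    """BERT-like encoder: embeddings -> stacked transformer layers -> mean-pooled sequence vector."""
    def __init__(self, n_layers=6, n_heads=8, d_out=256):
        super().__init__()
        self.embed = MidiEventEmbedding()
        layer = nn.TransformerEncoderLayer(d_model=self.embed.d_model, nhead=n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.proj = nn.Linear(self.embed.d_model, d_out)  # projection head into the embedding space

    def forward(self, pitch, velocity, time):
        h = self.encoder(self.embed(pitch, velocity, time))
        return self.proj(h.mean(dim=1))            # (batch, d_out) sequence embedding
```

A training step would then encode the original and the augmented view with this same encoder and feed the two resulting batches of embeddings into the triplet loss sketched above.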

A possible alternative would be to train a masked language model over the notes and velocities, but that would require multiple heads: one for classifying the pitch, one for regressing the velocity, and so on, roughly as sketched below.
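For completeness, this is the kind of multi-head setup I mean for the MLM alternative, again just a sketch with assumed dimensions, where only masked positions contribute to the loss:

```python
import torch
import torch.nn as nn

class MidiMLMHeads(nn.Module):
    """Per-position heads on top of the shared encoder output: pitch classification + velocity regression."""
    def __init__(self, d_model=128):
        super().__init__()
        self.pitch_head = nn.Linear(d_model, 128)   # classify the masked pitch (128 classes)
        self.velocity_head = nn.Linear(d_model, 1)  # regress the masked velocity

    def forward(self, hidden):                      # hidden: (batch, seq_len, d_model)
        return self.pitch_head(hidden), self.velocity_head(hidden).squeeze(-1)

def mlm_loss(pitch_logits, vel_pred, pitch_targets, vel_targets, mask):
    # mask: boolean (batch, seq_len) tensor marking the masked positions.
    ce = nn.functional.cross_entropy(pitch_logits[mask], pitch_targets[mask])
    mse = nn.functional.mse_loss(vel_pred[mask], vel_targets[mask])
    return ce + mse
```

My worry is that this adds one head and one loss term per feature, whereas the contrastive setup only needs the single sequence-level embedding.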