Using transformers without positional encoding for non-ordinal data

Hey all, I am wondering if anyone has attempted something along the lines of the topic title.

The question is: could you use transformers and MLM for non-ordinal data, e.g. a set? The attention mechanism is inherently fine with this; the problem is the mask tokens. If you have multiple mask tokens (i.e. multiple missing items from the set), then since they are identical inputs with no positional information to distinguish them, they will end up with the same state as they go through the model → not useful for predicting >=2 missing items.
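
To make the symmetry problem concrete, here is a toy check (a sketch in PyTorch with made-up dimensions, not from any real setup):

```python
import torch
import torch.nn as nn

# Toy check of the symmetry problem: with no positional encoding, two identical
# [MASK] inputs get identical hidden states, so they can never be decoded into
# two different missing items.
torch.manual_seed(0)
d_model = 16
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dropout=0.0, batch_first=True),
    num_layers=2,
)
items = torch.randn(1, 3, d_model)         # three observed set items
mask = torch.randn(1, 1, d_model)          # the shared [MASK] embedding
x = torch.cat([items, mask, mask], dim=1)  # (1, 5, d_model): two masked slots
out = encoder(x)
print(torch.allclose(out[0, 3], out[0, 4]))  # True: the two mask states are identical
```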

Does anyone have any resources regarding this? My current proposed solution would be to add some noise to the mask tokens to break the symmetry, then to use the Hungarian algorithm for assignment when calculating the loss. An alternative solution might be to use a soft positional embedding applied only to the masked tokens, then Hungarian assignment again. I think Hungarian assignment allows the true set-like nature to be retained, which then means it's a problem of how you go about breaking the symmetry.
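
For concreteness, this is roughly what I have in mind (a minimal sketch assuming PyTorch and SciPy; `mask_emb`, `noise_std` and the function names are just placeholders, not a tested implementation):

```python
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

def noisy_mask_embeddings(mask_emb, num_masked, noise_std=0.1):
    """Replicate the learned [MASK] embedding and add Gaussian noise so the
    masked positions start from slightly different states (symmetry breaking)."""
    base = mask_emb.unsqueeze(0).expand(num_masked, -1)
    return base + noise_std * torch.randn_like(base)

def hungarian_mask_loss(logits, target_ids):
    """Permutation-invariant cross-entropy over the masked positions of one sample.

    logits:     (num_masked, vocab_size) predictions at the masked positions
    target_ids: (num_masked,) the true missing items, in arbitrary order
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # cost[i, j] = negative log-likelihood of predicting target j at masked position i
    cost = -log_probs[:, target_ids]                  # (num_masked, num_masked)
    row_idx, col_idx = linear_sum_assignment(cost.detach().cpu().numpy())
    row_idx = torch.as_tensor(row_idx, device=cost.device)
    col_idx = torch.as_tensor(col_idx, device=cost.device)
    return cost[row_idx, col_idx].mean()
```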

Thanks for reading and please let me know what you think and if you have any resources regarding this!


In case anyone is following this, here are my thoughts so far: the Hungarian assignment presents an issue (although perhaps a solvable one). It may be biasing learning toward dataset-global statistics as opposed to intra-sample statistics.

This may be best shown by a simplified example.
Take these true sets: ABCE, BCD, BCE

Instead of learning complex rules like "the presence of D means there cannot be A or E" and "the presence of E indicates there is a possibility of A", the model could just learn the frequency of occurrence:
BCEAD (items ordered by how often they appear across the dataset)
This way, when it comes to any sample, Hungarian assignment will give it the benefit of the doubt and find the best way to minimise the loss. This is surprisingly effective! In fact, generated sets from the model recapture the (marginal) item distribution almost identically, as you would expect. The catch is that if you plot the co-occurrence matrices for real vs generated data, the generated data contains only noise. It has not learnt the underlying patterns of the data, despite effectively minimising the loss.
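
For anyone who wants to run the same check, this is roughly the diagnostic I mean (a sketch; `real_sets` / `generated_sets` are assumed to be lists of sets of integer item ids, and `vocab_size` is the number of distinct items):

```python
import numpy as np
from itertools import combinations

def cooccurrence_matrix(sets, vocab_size):
    """Count how often each pair of distinct items appears in the same set."""
    C = np.zeros((vocab_size, vocab_size))
    for s in sets:
        for i, j in combinations(sorted(s), 2):
            C[i, j] += 1
            C[j, i] += 1
    return C

def cooccurrence_correlation(real_sets, generated_sets, vocab_size):
    """Pearson correlation between the off-diagonal co-occurrence counts of
    real and generated data: ~1 means the pairwise structure is captured,
    ~0 means the generated co-occurrences are essentially noise."""
    off_diag = ~np.eye(vocab_size, dtype=bool)
    a = cooccurrence_matrix(real_sets, vocab_size)[off_diag]
    b = cooccurrence_matrix(generated_sets, vocab_size)[off_diag]
    return np.corrcoef(a, b)[0, 1]
```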

My proposed solution to this would be to set quite a low temperature on the softmax in the loss function. This way, the model must be more confident in its predictions, and it's harder to get away with simply modelling the global data statistics.
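
Concretely, that would just change the cost used in the Hungarian matching from my earlier sketch (again a sketch; `temperature=0.5` is an arbitrary value):

```python
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

def hungarian_mask_loss_t(logits, target_ids, temperature=0.5):
    """Same permutation-invariant loss as before, but with a temperature on the
    softmax: temperature < 1 sharpens the predicted distribution, the intent
    being that the model has to commit to its top predictions rather than
    hedging across the globally frequent items."""
    log_probs = F.log_softmax(logits / temperature, dim=-1)
    cost = -log_probs[:, target_ids]
    row_idx, col_idx = linear_sum_assignment(cost.detach().cpu().numpy())
    row_idx = torch.as_tensor(row_idx, device=cost.device)
    col_idx = torch.as_tensor(col_idx, device=cost.device)
    return cost[row_idx, col_idx].mean()
```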

NB: I should add that the loss curves are quite curious. There are multiple plateaus; early on, the model does actually learn correlations and achieves quite a high co-occurrence correlation (>0.8) with the real data. But the loss has a second wind after that and reverts to learning the global statistics.
