Why do I get different embeddings when I perform batch encoding in huggingface MT5 model?

I am trying to encode some text using HuggingFace’s mt5-base model. I am using the model as shown below:

import torch
from transformers import MT5EncoderModel, AutoTokenizer

model = MT5EncoderModel.from_pretrained("google/mt5-base")
tokenizer = AutoTokenizer.from_pretrained("google/mt5-base")

def get_t5_embeddings(texts):
    # Tokenize the batch with padding so the sequences can be stacked into one tensor
    input_ids = tokenizer(texts, return_tensors="pt", padding=True).input_ids
    last_hidden_state = model(input_ids=input_ids).last_hidden_state
    # Max-pool over the sequence dimension; torch.max returns (values, indices)
    pooled_sentence = torch.max(last_hidden_state, dim=1)
    return pooled_sentence[0].detach().numpy()

I was doing some experiments when I noticed that the same text had a low cosine similarity score with itself. I did some digging and realized that the model was returning very different embeddings when I did the encoding in batches. To validate this, I ran a small experiment that generates embeddings for "Hello" on its own and for batches of repeated "Hello"s of increasing size, and compares the embedding of the standalone "Hello" with that of the first "Hello" in each batch (both of which should be the same).

for i in range(1, 10):
    # count how many of the 768 embedding values match exactly
    print(i, (get_t5_embeddings(["Hello"])[0] == get_t5_embeddings(["Hello"] * i)[0]).sum())

This prints the number of values in the two embeddings (out of 768) that match exactly.
This was the result:

1 768
2 768
3 768
4 768
5 768
6 768
7 768
8 27
9 27

Every time I run it, I get mismatches once the batch size gets large enough (here, 8 or more).

Why am I getting different embeddings and how do I fix this?


Take a look at the answer on Stack Overflow: python - Why do I get different embeddings when I perform batch encoding in huggingface MT5 model? - Stack Overflow

Hope it helps!

@huggingface devs, it might be good if FeatureExtractionPipeline had some pooling options added, e.g.:

from transformers import pipeline

pipe = pipeline(
    "feature-extraction", "google/mt5-base",
    pooling_method="max"  # proposed argument
)

At least for min/max/mean pooling: users have likely been rolling their own versions, as exemplified by the SO answer (see the sketch below), and for newer users it can mean at least an hour of being stymied before hitting on the right torch tensor manipulation.
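
For reference, here is a minimal sketch of what such hand-rolled pooling typically looks like. This is only an illustration of the general pattern, not the exact code from the SO answer; the pooling argument name is my own, and the key idea is to use the tokenizer's attention_mask so that padded positions don't leak into the pooled vector:

import torch
from transformers import MT5EncoderModel, AutoTokenizer

model = MT5EncoderModel.from_pretrained("google/mt5-base")
tokenizer = AutoTokenizer.from_pretrained("google/mt5-base")

def get_pooled_embeddings(texts, pooling="max"):
    # Keep the attention mask so padded positions can be excluded from pooling
    encoded = tokenizer(texts, return_tensors="pt", padding=True)
    with torch.no_grad():
        last_hidden_state = model(**encoded).last_hidden_state  # (batch, seq_len, 768)
    mask = encoded.attention_mask.unsqueeze(-1)  # (batch, seq_len, 1), 1 = real token
    if pooling == "mean":
        # Sum over real tokens only, then divide by the number of real tokens
        pooled = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
    elif pooling == "max":
        # Push padded positions to -inf so they can never win the max
        pooled = last_hidden_state.masked_fill(mask == 0, float("-inf")).max(dim=1).values
    elif pooling == "min":
        pooled = last_hidden_state.masked_fill(mask == 0, float("inf")).min(dim=1).values
    else:
        raise ValueError(f"Unsupported pooling: {pooling}")
    return pooled.numpy()

# e.g. get_pooled_embeddings(["Hello", "a longer sentence"], pooling="mean")

Having something along these lines built into the pipeline would save people from re-deriving the masking step themselves.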