I am confused about how to determine the best embedding for a given entity from an LLM. If I run:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float16, trust_remote_code=True, device_map="auto")
inputs = tokenizer("some entity text", return_tensors="pt").to(model.device)
outputs = model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states  # tuple: one tensor per layer
For Llama2 and Mistral, hidden_states is a tuple of 33 tensors (the output of the embedding layer plus one per transformer layer), each of shape [batch_size, number_of_tokens, embedding_size]. Ultimately, when I think of “an” embedding, I am expecting something with the shape [batch_size, embedding_size] (or just [1, embedding_size]). How do I do this conversion?
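The only conversions I can think of are reductions over the token dimension, e.g. mean pooling or taking one particular token. Here is a sketch continuing from the snippet above; mean pooling over the last layer is just one arbitrary choice, which is exactly what I am unsure about:

last_hidden = hidden_states[-1]                                # [batch_size, number_of_tokens, embedding_size]
mask = inputs["attention_mask"].unsqueeze(-1)                  # [batch_size, number_of_tokens, 1]
mean_emb = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1)   # [batch_size, embedding_size]
last_tok_emb = last_hidden[:, -1, :]                           # [batch_size, embedding_size], final token (no padding)

Both give the shape I expect, but I do not know which reduction (or which layer) gives the “best” embedding.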
- Averaging and Pooling
The function get_pooling from here seems to suggest that there could be several ways of doing this. Here is a code snippet:
def get_pooling(outputs, inputs, pooling_strategy, padding_strategy):
    """
    :param outputs: torch.Tensor. Model outputs (without pooling)
    :param inputs: Dict. Model inputs
    :param pooling_strategy: str. Pooling strategy ['cls', 'cls_avg', 'cls_max', 'last', 'avg', 'max', 'all', index]
    :param padding_strategy: str. Padding strategy of tokenizers (`left` or `right`).
        It can be obtained by `tokenizer.padding_side`.
    """
    if pooling_strategy == 'cls':
        # vector of the first token
        outputs = outputs[:, 0]
    elif pooling_strategy == 'cls_avg':
        # mean of the first-token vector and the masked average over tokens
        avg = torch.sum(
            outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"])
        outputs = (outputs[:, 0] + avg) / 2.0
    elif pooling_strategy == 'cls_max':
        # mean of the first-token vector and the element-wise max over tokens (padding zeroed by the mask)
        maximum, _ = torch.max(outputs * inputs["attention_mask"][:, :, None], dim=1)
        outputs = (outputs[:, 0] + maximum) / 2.0
    elif pooling_strategy == 'last':
        # vector of the last non-padding token (with left padding this is simply position -1)
        batch_size = inputs['input_ids'].shape[0]
        sequence_lengths = -1 if padding_strategy == 'left' else inputs["attention_mask"].sum(dim=1) - 1
        outputs = outputs[torch.arange(batch_size, device=outputs.device), sequence_lengths]
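If I read the docstring correctly, this would be applied to the last hidden state roughly like this (my own guess at the call, not something taken from the library's docs):

last_hidden = hidden_states[-1]                                                   # [batch_size, number_of_tokens, embedding_size]
embedding = get_pooling(last_hidden, inputs, 'cls_avg', tokenizer.padding_side)   # [batch_size, embedding_size]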
Why are the outputs being multiplied by the attention_mask before the averaging and max operations? The HF outputs shown in ‘outputs’ above only contain positions up to the sequence length, so in that case can we just skip the attention_mask?
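Here is a minimal check of what I mean (same tokenizer as above; note that the Llama/Mistral tokenizers have no pad token by default, so I set one just for this test):

tokenizer.pad_token = tokenizer.eos_token   # assumption: reuse EOS as the padding token for this test

# A single unpadded input: the attention mask is all ones, so the masking looks redundant here.
single = tokenizer("Paris is the capital of France", return_tensors="pt")
print(single["attention_mask"])             # all ones

# A padded batch: the shorter sequence gets 0s in its mask, which I assume is what the code guards against.
batch = tokenizer(["Paris", "Paris is the capital of France"], return_tensors="pt", padding=True)
print(batch["attention_mask"])              # zeros at the padded positions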
- Finally, given the pooling question above, what do current LLMs return as an embedding? For example, OpenAI has an embeddings API that returns the embedding of a text. What approach are they using to calculate this embedding?
What would you consider to be the embedding of a string from an LLM?
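For concreteness, this is the kind of OpenAI call I am comparing against; it returns a single vector per input string, but as far as I can tell the pooling used internally is not documented (the model name below is just an example):

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
resp = client.embeddings.create(
    model="text-embedding-3-small",   # example embedding model
    input="Paris is the capital of France",
)
embedding = resp.data[0].embedding    # one flat list of floats per input string
print(len(embedding))                 # 1536 dimensions for this particular model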