Embeddings from llama2


I am trying to get sentence embeddings from a Llama 2 model. I tried using the feature-extraction pipeline and expected the output to be a tensor of size (seq_len, embedding_dim), but it is a nested list (list(list(list))).

It seems to be of size (seq_len, vocab_size)? Could you please help me understand why?

Or what is the right way to get a sentence embedding from a Llama model? Thanks!

from transformers import LlamaTokenizer, LlamaForCausalLM, pipeline
sentences = ["This is me", "A 2nd sentence"]
model_base_name = "meta-llama/Llama-2-7b-hf"
model = LlamaForCausalLM.from_pretrained(model_base_name)
tokenizer = LlamaTokenizer.from_pretrained(model_base_name)
feature_extraction = pipeline('feature-extraction', model=model, tokenizer=tokenizer)
embeddings = feature_extraction(sentences) # output should be of size (seq_len, embedding_dim) but is of size (seq_len, vocab_size)

(Pdb) len(embeddings[0][0][0])

(Pdb) len(embeddings[0][0])

(Pdb) len(embeddings[0])


I have the same situation as mentioned by Saaira.
Does anyone have any solutions or explanation for this?

Hi @Saaira,

I just found out why this happens and how to work around it.

The pipeline generally returns the first tensor in the model output, which for LlamaForCausalLM is the logits (hence the vocab_size last dimension):

  1. pipeline source code
  2. Llama doc
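This is easy to check without downloading the 7B weights: below is a minimal sketch using a tiny, randomly initialized Llama (the config values are made up for illustration). The first tensor in the output is the logits, which is vocab_size wide, while the hidden states have the embedding dimension.

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny random Llama so this runs locally (sizes are arbitrary).
config = LlamaConfig(hidden_size=64, intermediate_size=128,
                     num_hidden_layers=2, num_attention_heads=4,
                     vocab_size=1000)
model = LlamaForCausalLM(config)

input_ids = torch.tensor([[1, 5, 42, 7]])  # 4 dummy token ids
out = model(input_ids, output_hidden_states=True, return_dict=True)

# The first tensor in the output is the logits -> vocab_size wide.
print(out.logits.shape)             # torch.Size([1, 4, 1000])
# The hidden states are the embeddings -> hidden_size wide.
print(out.hidden_states[-1].shape)  # torch.Size([1, 4, 64])
```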

Instead of using the pipeline, for efficiency and cleaner code you can call the model directly:

import torch
outputs = model(torch.tensor([tokenizer(sentences)['input_ids'][0]]), return_dict=True, output_hidden_states=True)

This gives you the hidden states from all the layers (including the embedding layer) for each token. For the first sentence you will get:

len(outputs['hidden_states']), outputs['hidden_states'][0].shape
(33, torch.Size([1, 4, 4096]))
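To turn those per-token hidden states into a single sentence vector, one common choice is mean pooling over the last hidden state, masking out padding. A minimal sketch, where `hidden` and `mask` stand in for `outputs['hidden_states'][-1]` and the tokenizer's `attention_mask`:

```python
import torch

def mean_pool(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # hidden: (batch, seq_len, dim); mask: (batch, seq_len) of 0/1
    mask = mask.unsqueeze(-1).type_as(hidden)   # (batch, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)         # (batch, dim)
    counts = mask.sum(dim=1).clamp(min=1e-9)    # (batch, 1)
    return summed / counts

hidden = torch.randn(2, 4, 8)  # fake batch: 2 sentences, 4 tokens, dim 8
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(mean_pool(hidden, mask).shape)  # torch.Size([2, 8])
```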

Thanks @jasperlp Very helpful!

What did you find to be the best pooling strategy with llama embeddings?

Haven’t tried too much on that. I suppose it’s a task-by-task thing.
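Besides mean pooling, another option often used with decoder-only models is last-token pooling, since in a causal model only the final token has attended to the whole sentence. A sketch, using the same placeholder `hidden`/`mask` shapes as above:

```python
import torch

def last_token_pool(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # hidden: (batch, seq_len, dim); mask: (batch, seq_len) of 0/1
    lengths = mask.sum(dim=1) - 1                # index of last real token
    batch_idx = torch.arange(hidden.size(0))
    return hidden[batch_idx, lengths]            # (batch, dim)

hidden = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(last_token_pool(hidden, mask).shape)       # torch.Size([2, 8])
```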

You can use AnglE-LLaMA to extract sentence embedding from LLaMA/LLaMA2: GitHub - SeanLee97/AnglE: Angle-optimized Text Embeddings | 🔥 New SOTA


What is the max input length?