I am looking to retrieve the final layer hidden states for every token, which includes those in the input sequence, while excluding the language model head (lm_head). How can I go about this?
To retrieve the final-layer hidden states for every token, including those in the input sequence, while still using model.generate(), you can follow these steps. Note that hidden states are always taken before the lm_head projection, so they are exactly the representations you are asking for.
Set the correct configuration: Ensure that output_hidden_states=True and return_dict_in_generate=True are set in your GenerationConfig (or passed directly to model.generate()); a minimal example follows this list.
Use the forward() method: First, pass the input tokens through the model (which calls model.forward() under the hood) to get the final-layer hidden states of the input sequence.
Generate the output: Then, call model.generate() to produce the output tokens and capture their hidden states.
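For reference, here is a minimal sketch of just the two relevant flags in a GenerationConfig (the full script below sets these alongside the sampling options):
generate_config = GenerationConfig(
    output_hidden_states=True,     # return hidden states from every forward pass during generation
    return_dict_in_generate=True,  # return a structured ModelOutput instead of a bare tensor of token ids
)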
Here’s the modified code:
from transformers import MistralForCausalLM, AutoTokenizer, GenerationConfig
import torch
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
prompt = "[INST] What's your name? [/INST]"
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
# Step 1: Forward pass to get final-layer hidden states for the input tokens
with torch.no_grad():
    input_hidden_states = model(
        input_ids,
        output_hidden_states=True,
    ).hidden_states[-1]  # shape: (batch_size, input_len, hidden_size)
# Step 2: Generate with the model and capture hidden states for the generated tokens
generate_config = GenerationConfig(
    do_sample=True,    # sample, so that temperature/top_p/top_k are actually used
    temperature=1.0,
    top_p=0.75,
    top_k=40,
    num_beams=1,       # beam search would reshape hidden states to (batch_size * num_beams, ...)
    max_new_tokens=32,
    output_hidden_states=True,
    return_dict_in_generate=True,
)
generation_output = model.generate(
    input_ids,
    generation_config=generate_config,
)
# generation_output.hidden_states has one entry per generation step; each entry is a
# tuple over layers, so [step][-1] is the final layer at that step. Step 0 covers the
# prompt, each later step covers exactly one newly generated token.
generated_hidden_states = torch.cat(
    [step[-1] for step in generation_output.hidden_states[1:]], dim=1
)  # (batch_size, num_generated_tokens - 1, hidden_size): the last token is never fed back
# Combine input and generated hidden states along the sequence dimension
hidden_states = torch.cat([input_hidden_states, generated_hidden_states], dim=1)
print(hidden_states.shape)  # (batch_size, input_len + num_generated_tokens - 1, hidden_size)
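Since the question specifically asks for representations that exclude the lm_head: the tensors in hidden_states are taken before that projection. As a quick sanity check (a sketch, assuming a recent transformers version where the last hidden_states entry already has the final norm applied), projecting the final hidden states through model.lm_head should reproduce the forward-pass logits:
# Sanity check (sketch): the final hidden states sit just before the lm_head,
# so projecting them through model.lm_head reproduces the forward-pass logits.
with torch.no_grad():
    forward_out = model(input_ids, output_hidden_states=True)
    recomputed_logits = model.lm_head(forward_out.hidden_states[-1])
    print(torch.allclose(recomputed_logits, forward_out.logits, atol=1e-4))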
Explanation:
Get Input Hidden States: Calling the model (i.e., model.forward()) with output_hidden_states=True returns a tuple of hidden states, one per layer; the last element is the final layer, which sits just before the lm_head.
Generate and Capture Output Hidden States: With output_hidden_states=True and return_dict_in_generate=True, model.generate() returns generation_output.hidden_states as a tuple with one entry per generation step, where each entry is itself a tuple with one tensor per layer (see the short sketch after this list).
Combine Hidden States: Finally, the input hidden states and the per-step generated hidden states are concatenated along the sequence dimension, giving the final-layer hidden states for every token in the sequence (the very last generated token is the one exception, since it is never fed back through the model).
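To make that nested structure concrete, here is a short sketch (reusing generation_output from the code above) that just prints the relevant lengths and shapes:
# One entry per generation step; each entry has one tensor per layer (embeddings + decoder layers)
print(len(generation_output.hidden_states))           # number of generation steps
print(len(generation_output.hidden_states[0]))        # num_layers + 1
print(generation_output.hidden_states[0][-1].shape)   # (batch_size, prompt_len, hidden_size)
print(generation_output.hidden_states[1][-1].shape)   # (batch_size, 1, hidden_size)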
By following these steps, you obtain the final-layer hidden states, without the lm_head being applied, for every token, including those in the input sequence, while still using model.generate().
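As a side note, the first entry of generation_output.hidden_states already contains the prompt's hidden states, so the separate forward pass in Step 1 can be skipped entirely if you prefer; a sketch reusing the generation_output above:
# Step 0 already covers the prompt; later steps cover one generated token each
all_hidden_states = torch.cat(
    [step[-1] for step in generation_output.hidden_states], dim=1
)
print(all_hidden_states.shape)  # (batch_size, prompt_len + num_generated_tokens - 1, hidden_size)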