Hello Hugging Face Community!
I’m currently working on a project where I frequently batch-process samples that all share the same past_key_values (pkv). To save compute, I calculate the past_key_values once for the shared prefix and then duplicate it across the batch. The typical scenario is that I have some input text and want to evaluate different continuations of it.
For example, for the input text “The quick brown fox” I want to evaluate the log probs of each token in the two continuations “jumps over the lazy dog” and “dashes through the field”.
Here is some example code to illustrate:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("openchat/openchat_3.5")
# `pad_token` was not assigned by default - at least when I downloaded the model last
tokenizer.pad_token = tokenizer.special_tokens_map['additional_special_tokens'][1]
model = AutoModelForCausalLM.from_pretrained("openchat/openchat_3.5")
model.to(device)
model.eval()
# Function to batch encode texts with right padding - the tokenizer defaults to left padding.
# Since we are interested in token probabilities, not further generation, right padding is fine.
def batch_encode(tokenizer, text_list, device):
    input_ids = [tokenizer.encode(text, add_special_tokens=False) for text in text_list]
    max_len = max(len(ids) for ids in input_ids)
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in input_ids]
    input_ids = [ids + [tokenizer.pad_token_id] * (max_len - len(ids)) for ids in input_ids]
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long).to(device),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long).to(device),
    }
# Define the prompt template for the shared prefix
prompt_template = "<s>GPT4 Correct User: {input_text}<|end_of_turn|>GPT4 Correct Assistant:"

# Example usage with an instruction - the shared prefix is the fully formatted prompt
input_text = "Complete the following sentence: The quick brown fox"
prompt = prompt_template.format(input_text=input_text)
input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(device)

# Generate model outputs including past_key_values - we don't need hidden_states or attentions
with torch.no_grad():
    output = model(input_ids=input_ids, use_cache=True, output_hidden_states=False, output_attentions=False)
past_key_values = output.past_key_values
# Define a function to duplicate past_key_values along the batch dimension (dim=0)
def duplicate_pkv(pkv, num_repeats=2):
    return tuple(tuple(torch.cat([tensor] * num_repeats, dim=0) for tensor in layer) for layer in pkv)
# Continuation texts for the input - encoded as-is, since the chat template is already part of the cached prefix
continuations = ["jumps over the lazy dog", "dashes through the field"]
num_continuations = len(continuations)

# Encode the continuations using the batch_encode function
encoded_continuations = batch_encode(tokenizer, continuations, device)
# Generate model outputs for the continuations
with torch.no_grad():
    continuation_outputs = model(
        input_ids=encoded_continuations["input_ids"],
        # duplicate the input `past_key_values` `num_continuations` times
        # NOTE: this is where the memory usage can be very high
        past_key_values=duplicate_pkv(past_key_values, num_repeats=num_continuations),
        # we only need the logits, so set other outputs to False
        use_cache=False,
        output_hidden_states=False,
        output_attentions=False,
    )
```
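For context, the scoring step afterwards looks roughly like this (just a sketch using the variables above; the first continuation token is scored from the last prefix position, the remaining ones from the continuation logits):

```python
import torch.nn.functional as F

cont_ids = encoded_continuations["input_ids"]        # (batch, cont_len)
cont_mask = encoded_continuations["attention_mask"]  # (batch, cont_len)
logits = continuation_outputs.logits                 # (batch, cont_len, vocab)

# The logits at continuation position t predict token t+1; the first
# continuation token is predicted by the last position of the prefix pass.
first_tok_logits = output.logits[:, -1, :].expand(cont_ids.size(0), -1)
shifted_logits = torch.cat([first_tok_logits.unsqueeze(1), logits[:, :-1, :]], dim=1)

log_probs = F.log_softmax(shifted_logits.float(), dim=-1)
token_log_probs = log_probs.gather(-1, cont_ids.unsqueeze(-1)).squeeze(-1)

# Zero out padding positions and sum per continuation
sequence_log_probs = (token_log_probs * cont_mask).sum(dim=-1)
for text, lp in zip(continuations, sequence_log_probs.tolist()):
    print(f"{text!r}: {lp:.3f}")
```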
This approach becomes particularly memory-intensive for long input sequences or large batches, since every sample in the batch ends up holding its own full copy of the prefix keys and values for every layer. I’m curious to hear if there’s a more memory-efficient way to handle this.
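One idea I’ve been toying with, but haven’t verified, is to hand the model broadcast views instead of real copies via a hypothetical duplicate_pkv_view variant (whether the attention implementation accepts these, or just materializes them anyway when it concatenates the new keys, probably depends on the transformers version):

```python
# Hypothetical, untested alternative to duplicate_pkv: expand returns broadcast
# views that share the underlying memory instead of materializing num_repeats copies.
def duplicate_pkv_view(pkv, num_repeats=2):
    return tuple(
        tuple(tensor.expand(num_repeats, *tensor.shape[1:]) for tensor in layer)
        for layer in pkv
    )
```

The other fallback I can think of is to loop over the continuations in smaller micro-batches while reusing the same cache, which trades throughput for memory.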
Any insights or suggestions to optimize memory usage under these constraints would be greatly appreciated!