Is There a Way to Improve Memory Usage When Using Identical `past_key_values` for All Samples in a Batch?

Hello Hugging Face Community!

I’m currently working on a project where I frequently batch-process samples that all share identical `past_key_values` (pkv). I compute `past_key_values` once for the shared prefix, which saves compute, and then duplicate it once for every sample in the batch. The typical scenario is having some input text and evaluating several different continuations of it.
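To make the cost of the duplication step concrete, here is a small sketch on random tensors (shapes chosen arbitrarily for illustration). It compares `torch.cat`, which is how the duplication is done in the code further down, with `Tensor.expand`, which returns a view that shares the original storage. This only measures the duplication itself. Whether a model keeps such a view or copies it back to full size anyway (for example, when its attention layers concatenate the new keys onto the cached ones) depends on the transformers version and attention implementation, so `expand` is an assumption to test, not a known fix.

```python
import torch

# Stand-in for one layer's cached keys: (batch=1, num_kv_heads, seq_len, head_dim).
# The sizes are arbitrary, chosen only for illustration.
key = torch.randn(1, 8, 512, 128)
num_repeats = 4

# torch.cat allocates a new tensor holding num_repeats full copies of the data.
copied = torch.cat([key] * num_repeats, dim=0)

# Tensor.expand returns a view: the batch dimension gets stride 0, so every
# "sample" reads the same underlying memory. Caveat (assumption): the model may
# still materialize a full-size tensor later, e.g. when concatenating new keys.
shared = key.expand(num_repeats, -1, -1, -1)


def storage_bytes(t: torch.Tensor) -> int:
    # Size of the memory block actually backing the tensor.
    return t.untyped_storage().nbytes()


print(copied.shape, storage_bytes(copied))  # num_repeats times the storage of `key`
print(shared.shape, storage_bytes(shared))  # same storage as `key`
print(shared.stride(0))                     # 0: no memory per extra "sample"
```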

For example, for the input text “The quick brown fox”, I want to evaluate the log probs of each token in the two continuations “jumps over the lazy dog” and “dashes through the field”.
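Written out, for a prefix $x$ and a continuation $c = (c_1, \ldots, c_T)$, the quantity being scored is the chain-rule decomposition below. The per-token log probs are its individual terms, and each term is read from the logits at the position just before $c_t$ (for $c_1$, the last position of the prefix).

$$
\log p(c \mid x) = \sum_{t=1}^{T} \log p(c_t \mid x, c_1, \ldots, c_{t-1})
$$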

Here is some example code to illustrate:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The tokenizer has no device; only the model and the input tensors are moved to `device`
tokenizer = AutoTokenizer.from_pretrained("openchat/openchat_3.5")
# `pad_token` was not assigned by default, at least when I last downloaded the model
tokenizer.pad_token = tokenizer.special_tokens_map['additional_special_tokens'][1]
model = AutoModelForCausalLM.from_pretrained("openchat/openchat_3.5")
model.to(device)

# Function to batch encode texts with right padding (the tokenizer defaults to left padding).
# Since we are interested in token probabilities, not further generation, right padding is fine.
def batch_encode(tokenizer, text_list, device):
    input_ids = [tokenizer.encode(text, add_special_tokens=False) for text in text_list]
    max_len = max(len(ids) for ids in input_ids)
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in input_ids]
    input_ids = [ids + [tokenizer.pad_token_id] * (max_len - len(ids)) for ids in input_ids]
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long).to(device),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long).to(device),
    }

# Define the prompt template
prompt_template = "<s>GPT4 Correct User: {input_text}<|end_of_turn|>GPT4 Correct Assistant:"

# Example usage with instruction: the template wraps the shared prefix only,
# so the continuations below are scored as the start of the assistant's reply
input_text = "Complete the following sentence: The quick brown fox"
prompt = prompt_template.format(input_text=input_text)
input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(device)

# Run the prefix once and keep its past_key_values; we don't need hidden_states or attentions.
# torch.no_grad() avoids storing activations for a backward pass we never run.
with torch.no_grad():
    output = model(input_ids=input_ids, use_cache=True, output_hidden_states=False, output_attentions=False)
past_key_values = output.past_key_values

# Define a function to duplicate past_key_values along the batch dimension
# (assumes the legacy tuple of per-layer (key, value) tuples)
def duplicate_pkv(pkv, num_repeats=2):
    return tuple(tuple(torch.cat([tensor] * num_repeats, dim=0) for tensor in layer) for layer in pkv)

# Continuation texts for the input
continuations = ["jumps over the lazy dog", "dashes through the field"]
num_continuations = len(continuations)

# Encode the continuations (without the prompt template) using the batch_encode function
encoded_continuations = batch_encode(tokenizer, continuations, device)

# Generate model outputs for the continuations
with torch.no_grad():
    continuation_outputs = model(
        input_ids=encoded_continuations["input_ids"],
        # duplicate the input `past_key_values` `num_continuations` times
        # NOTE: this is where the memory usage can be very high
        past_key_values=duplicate_pkv(past_key_values, num_repeats=num_continuations),
        # attention_mask is omitted: with right padding, the causal mask already keeps
        # real tokens from attending to the trailing pad tokens
        # we only need the logits, so set other outputs to False
        use_cache=False,
        output_hidden_states=False,
        output_attentions=False,
    )
```
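To turn these logits into the per-token log probs, each continuation token has to be paired with the logits at the position just before it: the first continuation token is predicted by the last position of the prefix pass, and every later token by the previous position of the batched pass. The sketch below does that alignment; the helper name `continuation_token_logprobs` is hypothetical, and the demo uses random tensors in place of real model outputs.

```python
import torch
import torch.nn.functional as F


def continuation_token_logprobs(prefix_logits, continuation_logits, continuation_ids, attention_mask):
    """Per-token log probs of right-padded continuations that share one prefix.

    prefix_logits:       (1, prefix_len, vocab) logits from the prefix pass
    continuation_logits: (batch, cont_len, vocab) logits from the batched pass
    continuation_ids:    (batch, cont_len) right-padded token ids
    attention_mask:      (batch, cont_len), 1 for real tokens and 0 for padding
    Returns a (batch, cont_len) tensor with 0.0 at padded positions.
    """
    batch_size = continuation_ids.shape[0]
    # The first continuation token is predicted by the last prefix position...
    first = prefix_logits[:, -1:, :].expand(batch_size, -1, -1)
    # ...and token t (t >= 1) by continuation position t - 1.
    aligned = torch.cat([first, continuation_logits[:, :-1, :]], dim=1)
    log_probs = F.log_softmax(aligned.float(), dim=-1)
    token_log_probs = log_probs.gather(-1, continuation_ids.unsqueeze(-1)).squeeze(-1)
    # Zero out padded positions so they don't count towards a continuation's total.
    return token_log_probs.masked_fill(attention_mask == 0, 0.0)


# Demo with random logits: 2 continuations, a vocab of 10, the second one padded by 2 tokens.
torch.manual_seed(0)
prefix_logits = torch.randn(1, 5, 10)
continuation_logits = torch.randn(2, 4, 10)
continuation_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])

token_lp = continuation_token_logprobs(prefix_logits, continuation_logits, continuation_ids, attention_mask)
print(token_lp.sum(dim=-1))  # total log prob of each continuation
```

With the variables from the code above, the call would be `continuation_token_logprobs(output.logits, continuation_outputs.logits, encoded_continuations["input_ids"], encoded_continuations["attention_mask"])`. Note that tokenizing the prefix and the continuations separately can produce different tokens at the boundary than tokenizing the full string in one go.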

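This approach becomes particularly memory-intensive for long input sequences or large batches: `torch.cat` materializes a full copy of every layer's key and value tensors for each continuation, so the memory of the duplicated cache grows in proportion to both the prefix length and the number of continuations. I’m curious to hear if there’s a more memory-efficient way to handle this.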
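To put rough numbers on this: each layer caches one key and one value tensor of shape `(batch, num_kv_heads, seq_len, head_dim)`, so the cache takes 2 × layers × batch × KV heads × sequence length × head dim × bytes per element. The sketch below assumes a Mistral-7B-style config (32 layers, 8 KV heads, head dim 128), on the assumption that `openchat/openchat_3.5` is a Mistral-7B fine-tune; the real values should be read from `model.config`.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size, bytes_per_elem):
    """Bytes held by a KV cache: one key and one value tensor per layer,
    each of shape (batch_size, num_kv_heads, seq_len, head_dim)."""
    return 2 * num_layers * batch_size * num_kv_heads * seq_len * head_dim * bytes_per_elem


# Assumed Mistral-7B-style config; read the real values from model.config
# (for this model family, typically head_dim = hidden_size // num_attention_heads).
ASSUMED_CONFIG = {"num_layers": 32, "num_kv_heads": 8, "head_dim": 128}
MIB = 2 ** 20

# A 1,000-token prefix duplicated for 1 vs. 16 continuations.
for batch_size in (1, 16):
    for dtype, bytes_per_elem in (("float32", 4), ("float16", 2)):
        size = kv_cache_bytes(**ASSUMED_CONFIG, seq_len=1000, batch_size=batch_size, bytes_per_elem=bytes_per_elem)
        print(f"batch {batch_size:>2}, {dtype}: {size / MIB:.0f} MiB")
# batch  1, float32: 250 MiB
# batch  1, float16: 125 MiB
# batch 16, float32: 4000 MiB
# batch 16, float16: 2000 MiB
```

In the code above the original `past_key_values` also stays referenced while the duplicated batch exists, so the total is one more prefix-sized copy on top of the batched one.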

Any insights or suggestions to optimize memory usage under these constraints would be greatly appreciated!