I am trying to build a library for constrained generation. The goal, hopefully, is to skip the forward pass entirely whenever there is only one possible next token.
The problem I am having is that the logits call is far too slow for constrained generation to be of any use. Is there a way to speed up the logits computation?
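To make the idea concrete, here is the shortcut I am after, as a rough sketch (`allowed_tokens` and `forward_logits` are placeholder names, not a real API):

import torch

def pick_next_token(allowed_tokens, forward_logits):
    # allowed_tokens: list of token ids permitted by the grammar at this step
    # forward_logits: zero-argument callable that runs the model and returns 1-D logits
    if len(allowed_tokens) == 1:
        return allowed_tokens[0]  # free: no forward pass needed
    logits = forward_logits()     # expensive: full forward pass
    probs = torch.softmax(logits[allowed_tokens], dim=-1)
    return allowed_tokens[torch.multinomial(probs, 1).item()]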
Here is an example that should reproduce the timing issue (my actual working code is in transformers-neuronx):
import time

import torch
from transformers import LlamaForCausalLM, AutoTokenizer

# Load the model and tokenizer
model_name = "meta-llama/Llama-2-7b-hf"
model = LlamaForCausalLM.from_pretrained(model_name, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)

input_prompt = "The quick brown fox"
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda")

num_iterations = 10

# Timing a single forward pass (logits for one next token) plus sampling
start_time = time.time()
for _ in range(num_iterations):
    with torch.no_grad():
        logits = model(input_ids).logits[0, -1, :]
    softmax_probs = torch.nn.functional.softmax(logits, dim=-1)
    next_token_index = torch.multinomial(softmax_probs, 1).item()
end_time = time.time()
logits_time = end_time - start_time
print(f"Time taken for generating text using logits: {logits_time / num_iterations} seconds")

# Timing the generation using the generate method
start_time = time.time()
for _ in range(num_iterations):
    generated_ids = model.generate(input_ids, max_length=10)
end_time = time.time()
generate_time = end_time - start_time
print(f"Time taken for generating text using generate: {generate_time / num_iterations} seconds")
Here is the constrained generation code:
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer
from transformers_neuronx.llama.model import LlamaForSampling

neuron_model = LlamaForSampling.from_pretrained(model_path + 'llama-2-7b-vicuna', batch_size=1, tp_degree=6, amp='bf16', context_length_estimate=[4000], n_positions=4000)
neuron_model.to_neuron()
tokenizer = AutoTokenizer.from_pretrained(model_path + 'llama-2-7b-vicuna')
class ConstrainedTextGenerator:
    def __init__(self, sequences, neuron_model, eos_token_id=2):
        self.neuron_model = neuron_model
        self.eos_token_id = eos_token_id  # id of "</s>" for Llama tokenizers
        self.tree = self.preprocess(sequences)
    def preprocess(self, sequences):
        # Build a prefix tree (trie) of token ids over all allowed sequences
        tree = {}
        for sequence in sequences:
            sequence_ids = self.encode(sequence)
            current_tree = tree
            for token in sequence_ids:
                token_item = token.item()  # convert 0-d tensor to int
                if token_item not in current_tree:
                    current_tree[token_item] = {}
                current_tree = current_tree[token_item]
            # Add </s> to mark the end of each sequence
            if self.eos_token_id not in current_tree:
                current_tree[self.eos_token_id] = {}
        return tree
    def encode(self, text):
        # Replace this with your encoding logic; it should return a 1-D tensor of token ids
        return tokenizer.encode(text, add_special_tokens=False, return_tensors="pt")[0]
def generate_text(self, input_prompt=""):
input_ids_list = [[]]
current_tree = self.tree
# Encode the input prompt
prompt_ids = self.encode(input_prompt)
# Append prompt_ids to input_ids_list
input_ids_list[0].extend(prompt_ids.tolist())
while True:
# Check if there are multiple options at the current position
if len(current_tree) > 1:
# Get the indices of the available tokens
available_indices = [list(current_tree.keys()).index(token) for token in current_tree.keys()]
# Choose the token based on the softmax probabilities
logits = self.neuron_model.forward(torch.tensor(input_ids_list, dtype=torch.long)).squeeze()
softmax_probs = torch.nn.functional.softmax(logits[available_indices], dim=-1)
# Sample from the softmax probabilities
next_token_index = torch.multinomial(softmax_probs, 1).item()
next_token = list(current_tree.keys())[available_indices[next_token_index]]
else:
# If there's only one option, skip forward and fill it in
next_token = list(current_tree.keys())[0]
input_ids_list[-1].append(next_token)
# Check if it's the end of a sequence
if next_token == self.eos_token_id.item():
break
else:
current_tree = current_tree.get(next_token, {})
# Remove the empty sequence at the end, if any
if not input_ids_list[-1]:
input_ids_list.pop()
input_ids = torch.tensor([token for seq in input_ids_list for token in seq], dtype=torch.long)
generated_text = ' '.join(map(str, input_ids.tolist()))
return input_ids
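For reference, this is roughly how I drive it (the allowed sequences and the prompt below are placeholders):

generator = ConstrainedTextGenerator(["Yes.", "No."], neuron_model)
output_ids = generator.generate_text(input_prompt="Is the sky blue? Answer: ")
print(tokenizer.decode(output_ids))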