Logits function too slow

I am trying to build a library for constrained generation. The goal, hopefully, is to skip generating text with the model whenever there is only one possible next token.

The problem I am having is the logits function is way too slow to allow constrained generation to be of any use. Is there a way to speed up logits?

Here is an example that might work (my actual working code uses neuronx).

import torch
from transformers import LlamaForCausalLM, AutoTokenizer
import time

# Reference Hugging Face checkpoint; the model and its tokenizer are both
# loaded from the same hub name.
model_name = "meta-llama/Llama-2-7b-hf"
model = LlamaForCausalLM.from_pretrained(
    model_name,
    device_map="cuda",
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
)

import time
# Number of repetitions each measurement is averaged over.
num_iterations = 10

# Tokenize once, outside the timed region, so the first loop measures only the
# forward pass plus sampling.  ``torch.as_tensor`` accepts either a list or an
# existing tensor and does not re-wrap (and warn about) an existing tensor.
prompt_ids = torch.as_tensor(generator.encode(input_prompt), dtype=torch.long)

# --- Timing a single next-token step through the raw logits ---------------
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# short durations; time.time() can jump when the system clock is adjusted.
start_time = time.perf_counter()
for _ in range(num_iterations):
    logits = generator.neuron_model.forward(prompt_ids).squeeze()
    softmax_probs = torch.nn.functional.softmax(logits, dim=-1)
    next_token_index = torch.multinomial(softmax_probs, 1).item()

end_time = time.perf_counter()
logits_time = end_time - start_time
print(f"Time taken for generating text using logits: {logits_time / num_iterations} seconds")

# --- Timing the generator's own generation entry point ---------------------
# NOTE(review): ConstrainedTextGenerator in this file exposes
# generate_text(input_prompt) and takes no max_length; confirm which object
# `generator` is and that it really provides generate(..., max_length=...).
# NOTE(review): this loop produces up to 10 tokens per iteration while the
# loop above produces one, so the two averages are not per-token comparable.
start_time = time.perf_counter()
for _ in range(num_iterations):
    generated_text = generator.generate(input_prompt=input_prompt,max_length=10)

end_time = time.perf_counter()
generate_time = end_time - start_time
print(f"Time taken for generating text using generate_text: {generate_time / num_iterations} seconds")

Here is the constrained generation code:

# The compiled model and its tokenizer are loaded from the same checkpoint
# directory, so the path is built once and reused.
# NOTE(review): `model_path` and `LlamaForSampling` are not defined or imported
# in the visible part of this file; they must come from elsewhere.
checkpoint_dir = model_path + 'llama-2-7b-vicuna'
neuron_model = LlamaForSampling.from_pretrained(
    checkpoint_dir,
    batch_size=1,
    tp_degree=6,
    amp='bf16',
    context_length_estimate=[4000],
    n_positions=4000,
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

import torch
import torch.nn.functional as F
import numpy as np

class ConstrainedTextGenerator:
    """Generate token sequences restricted to a fixed set of allowed strings.

    The allowed strings are tokenized and stored in a prefix tree made of
    nested dicts keyed by token id.  Every allowed sequence ends with an
    end-of-sequence marker node.  Generation walks the tree from the root:
    when a node has exactly one child, that token is taken without calling
    the model; the model is only queried when several tokens are possible.
    """

    def __init__(self, sequences, neuron_model, eos_token_id=2, tokenizer=None):
        """Store the model and build the prefix tree for ``sequences``.

        Args:
            sequences: Iterable of strings the generator is allowed to emit.
            neuron_model: Object whose ``forward(input_ids)`` returns logits.
            eos_token_id: Token id used as the end-of-sequence marker in the
                tree (stored as a plain ``int``).
            tokenizer: Optional tokenizer providing
                ``encode(text, add_special_tokens=False, return_tensors="pt")``.
                When omitted, the module-level ``tokenizer`` is used.
        """
        self.neuron_model = neuron_model
        self._tokenizer = tokenizer
        # Use the id passed by the caller rather than re-encoding "</s>":
        # encoding the literal text may yield more than one token, which
        # cannot be reduced to a single marker id.
        self.eos_token_id = int(eos_token_id)
        self.tree = self.preprocess(sequences)

    def preprocess(self, sequences):
        """Return the prefix tree (nested dicts) for ``sequences``.

        Each key is a token id and each value is the subtree of tokens that
        may follow it.  The end of every sequence is marked by a child keyed
        with ``self.eos_token_id``.
        """
        tree = {}
        for sequence in sequences:
            node = tree
            for token_id in self.encode(sequence).tolist():
                node = node.setdefault(token_id, {})
            # Mark the end of this sequence.
            node.setdefault(self.eos_token_id, {})
        return tree

    def encode(self, text):
        """Tokenize ``text`` without special tokens and return the id tensor.

        Returns the first row of the tokenizer's ``return_tensors="pt"``
        output.
        """
        # getattr keeps the method usable on objects that never ran __init__.
        active = getattr(self, "_tokenizer", None)
        if active is None:
            active = tokenizer  # module-level fallback
        return active.encode(text, add_special_tokens=False, return_tensors="pt")[0]

    def generate_text(self, input_prompt=""):
        """Walk the prefix tree and return prompt ids followed by chosen ids.

        Args:
            input_prompt: Text placed before the constrained continuation; it
                is part of the context given to the model.

        Returns:
            A 1-D ``torch.long`` tensor holding the prompt token ids followed
            by the generated token ids.  The end-of-sequence marker stops
            generation and is not included.
        """
        context = self.encode(input_prompt).tolist()
        node = self.tree

        # An empty node means there is nothing (more) to generate.
        while node:
            allowed = list(node)
            if len(allowed) > 1:
                # Several continuations are possible: ask the model.
                # NOTE: the whole context is passed on every call; nothing is
                # cached between calls here.
                # Assumes forward() output squeezes to a 1-D vector indexed by
                # token id - TODO confirm for the model in use.
                logits = self.neuron_model.forward(
                    torch.tensor([context], dtype=torch.long)
                ).squeeze()
                # Index by token id so only the allowed tokens compete;
                # cast to float32 before softmax and sampling.
                probs = torch.softmax(logits[allowed].float(), dim=-1)
                next_token = allowed[torch.multinomial(probs, 1).item()]
            else:
                # Only one option: fill it in without calling the model.
                next_token = allowed[0]

            if next_token == self.eos_token_id:
                break

            context.append(next_token)
            node = node[next_token]

        return torch.tensor(context, dtype=torch.long)