Trying to recreate `model.greedy_search()` for custom decoding of LLM output, but I am getting a different decoded output

I have recreated model.greedy_search() in two different ways, with the main difference being the shape of input_ids.

Model Initialization

import torch
import transformers

# USER CONFIGURATIONS
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "EleutherAI/gpt-neo-1.3B"

bnb_config = transformers.BitsAndBytesConfig(
   load_in_4bit = True,
   bnb_4bit_quant_type = "nf4",
   bnb_4bit_use_double_quant = True,
   bnb_4bit_compute_dtype = torch.bfloat16
)

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, cache_dir = "./Models/")
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, quantization_config = bnb_config, cache_dir = "./Models/")

Generation Configurations

# GENERATION INPUTS
num_gen_tokens = 20
prompt = "ewrcewkr oewrkcl ewrkewr\n"
input_ids = tokenizer(prompt, return_tensors = "pt").to(device).input_ids.squeeze() # batched = False

# CONFIRM PRE-TRAINED CONFIGURATIONS
model.generation_config.pad_token_id = model.generation_config.eos_token_id
assert tokenizer.bos_token_id == model.generation_config.bos_token_id
assert tokenizer.eos_token_id == model.generation_config.eos_token_id
assert not ((model.generation_config.eos_token_id is None) ^ (model.generation_config.pad_token_id is None))

Variation 1: My code, where input_ids.shape = torch.Size([num_input_tokens])

output_ids = input_ids.clone().detach()
model.eval()
with torch.no_grad():
    for _ in range(num_gen_tokens):
        if model.generation_config.eos_token_id is None or output_ids[-1] != model.generation_config.eos_token_id:
            outputs = model(output_ids)
            next_token_logits = outputs.logits[-1] # only consider the logits for the last input token
            next_tokens = next_token_logits.argmax(dim = -1).unsqueeze(dim = -1)
            output_ids = torch.cat((output_ids, next_tokens), dim = -1)

print(tokenizer.decode(output_ids))

Variation 2: My code, where input_ids.shape = torch.Size([1, num_input_tokens])

output_ids = input_ids.clone().detach().unsqueeze(dim = 0)
model.eval()
with torch.no_grad():
    for _ in range(num_gen_tokens):
        if model.generation_config.eos_token_id is None or output_ids[:, -1] != model.generation_config.eos_token_id:
            outputs = model(output_ids)
            next_token_logits = outputs.logits[:, -1] # only consider the logits for the last input token
            next_tokens = next_token_logits.argmax(dim = -1).unsqueeze(dim = -1)
            output_ids = torch.cat((output_ids, next_tokens), dim = -1)

print(tokenizer.decode(output_ids.squeeze()))

Variation 3: HuggingFace API, where input_ids.shape = torch.Size([1, num_input_tokens])

stopping_criteria = transformers.StoppingCriteriaList([
    transformers.MaxLengthCriteria(max_length = 20 + input_ids.size(dim = -1))
])
output = model.greedy_search(
    input_ids.clone().detach().unsqueeze(dim = 0),
    stopping_criteria = stopping_criteria
).squeeze()
print(tokenizer.decode(output))
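
For reference, the same greedy decoding can also be cross-checked through model.generate(); the sketch below assumes a transformers version that supports the max_new_tokens argument.

# Cross-check: greedy decoding via generate() is do_sample = False with num_beams = 1
generate_output = model.generate(
    input_ids.clone().detach().unsqueeze(dim = 0),
    max_new_tokens = num_gen_tokens,
    do_sample = False,
    num_beams = 1
)
print(tokenizer.decode(generate_output.squeeze()))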

In most cases, the generated tokens returned should be the same across all three methods. However, I found two cases that seem to violate this rule (note that only the prompt was changed; all other variables remained the same):

Case 1: prompt = "ewrcewkr oewrkcl ewrkewr\n", Variation 2 seems to be the odd one out

  1. Variation 1’s output
ewrcewkr oewrkcl ewrkewr

I am a very simple person. I love to read, watch movies, and play video games
  2. Variation 2’s output
ewrcewkr oewrkcl ewrkewr

I am a very simple person. I am very easy going and I like to be around people
  3. Variation 3’s output
ewrcewkr oewrkcl ewrkewr

I am a very simple person. I love to read, watch movies, and play video games

Case 2: prompt = tokenizer.bos_token + "ewrcewkr oewrkcl ewrkewr\n", Variation 3 seems to be the odd one out

  1. Variation 1’s output
<|endoftext|>ewrcewkr oewrkcl ewrkewr

wewrcewkr oewrkcl ewrkewr

(a
  2. Variation 2’s output
<|endoftext|>ewrcewkr oewrkcl ewrkewr

wewrcewkr oewrkcl ewrkewr

(a
  3. Variation 3’s output
<|endoftext|>ewrcewkr oewrkcl ewrkewr

The following is a list of the most common words in the English language.

The most

I have 3 questions regarding the difference in outputs (as seen above):

  1. What is the expected input shape for model.forward()? Is it torch.Size([1, num_input_tokens]) or torch.Size([num_input_tokens])? If the input has shape torch.Size([1, num_input_tokens]), outputs.logits has shape torch.Size([1, num_output_tokens, num_tokenizer_tokens]). If the input has shape torch.Size([num_input_tokens]), outputs.logits has shape torch.Size([num_output_tokens ** 2, num_tokenizer_tokens]). (See the shape check after this list.)
  2. Does my code correctly model how the LLM decodes the output via greedy search?
  3. What is causing this difference in decoded output across all 3 methods used?
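
To make question 1 concrete, here is a minimal shape check (the batched case reliably gives a 3-D logits tensor; the unbatched behaviour may differ between architectures):

# Minimal shape check: compare logits for unbatched vs. batched input_ids
with torch.no_grad():
    unbatched_logits = model(input_ids).logits                   # input shape: (num_input_tokens,)
    batched_logits = model(input_ids.unsqueeze(dim = 0)).logits  # input shape: (1, num_input_tokens)
print(unbatched_logits.shape)  # varies with how the model handles unbatched input
print(batched_logits.shape)    # torch.Size([1, num_input_tokens, num_tokenizer_tokens])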

Thank you in advance.

By the way, I have also recreated model.beam_search() and have found no issues after testing for a while. It seems the problem is only with model.greedy_search(), even though the algorithm is deterministic.

I figured out the answers to my own questions, so I’ll share my findings here.

Q1 and Q2: The expected input shape is torch.Size([batch_size, num_input_tokens]), which means the input is always batched. torch.Size([num_input_tokens]) simply does not work for other models such as Llama 2, because of how data is processed under the hood in model.forward(). As such, variation 1 is not recommended and variation 2 is the way to go.
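
In other words, since tokenizer(prompt, return_tensors = "pt") already returns input_ids with shape (1, num_input_tokens), the .squeeze() in my setup can simply be dropped. A minimal sketch of the batched setup:

# Keep the batch dimension the tokenizer already provides instead of squeezing it away
encoded = tokenizer(prompt, return_tensors = "pt").to(device)
input_ids = encoded.input_ids               # shape: (1, num_input_tokens)
output_ids = input_ids.clone().detach()     # already batched, ready for variation 2's loop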

Q3: I tried to replicate the issue on Colab (I had originally used my local machine’s GPU) and found that case 1’s variation 2 should produce the same output as variation 3, and case 2’s variation 3 should produce the same output as variation 2. The code snippets for variations 2 and 3 that I shared above are correct as they are. However, you might want to change unsqueeze(dim = -1) to view(-1, 1), because during one of my trials one of the tensors had its last dimension expanded for some reason (which was wrong!).
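
For completeness, this is the small change inside variation 2’s loop (a sketch of the suggested replacement; nothing else changes):

# Build the next-token column with view(-1, 1) instead of unsqueeze(dim = -1),
# so it always has shape (batch_size, 1)
next_tokens = next_token_logits.argmax(dim = -1).view(-1, 1)
output_ids = torch.cat((output_ids, next_tokens), dim = -1)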
