Computing log probability of an arbitrary sequence given another sequence

I am exploring ways to calculate the log probability of each token in a provided output sequence (not generated by the model) given an input sequence. For example, with input “I like” and output “ice cream”, I want to calculate p(“ice cream” | “I like”) using a LLaMA-based model and AutoModelForCausalLM. I’m considering two ways to do this:

  1. Concatenate the two sequences, run the model once on the full sequence, and take log_softmax() of the logits at the positions corresponding to the output tokens.

  2. Incrementally append output tokens to the input sequence, one at a time, recording the log_softmax() probability of each newly added token.

Method 2 seems more intuitive to me, but it is much slower for longer outputs, so I’m leaning towards Method 1. However, the two approaches return different log probs, even when the output sequence contains only a single token. Does anyone have any insight into why there is such a difference? Which method would you recommend?
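To be explicit, both methods are intended to compute the same quantity, namely the sum of the per-token conditional log probabilities of the output given the input:

$\log p(\text{output} \mid \text{input}) = \sum_{t=1}^{T} \log p(y_t \mid \text{input}, y_{<t})$

where y_t is the t-th output token.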

Here is my code for Method 1:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "Yukang/LongAlpaca-7B"
tokenizer = AutoTokenizer.from_pretrained(name, device_map="auto", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(name, device_map="auto", trust_remote_code=True)
inputs = "I like"
outputs = "ice cream"

# Encode input and output
input_tokens = tokenizer.encode(inputs, add_special_tokens=False, return_tensors='pt')
output_tokens = tokenizer.encode(outputs, add_special_tokens=False, return_tensors='pt')

# Concatenate input and output tokens
tokens = torch.cat([input_tokens, output_tokens], dim=1)

# Get model predictions for the entire sequence at once
with torch.no_grad():
    model_outputs = model(tokens)  # avoid shadowing the `outputs` string above
    logits = model_outputs.logits

log_sum = 0
# Positions input_len-1 ... seq_len-2 hold the logits that predict the output tokens
range_index = range(input_tokens.shape[1] - 1, tokens.shape[1] - 1)
for i in range_index:
    past_tok, current_tok = i, i + 1
    token_logit = logits[0, past_tok, :]
    token_log_probs = torch.nn.functional.log_softmax(token_logit, dim=-1)
    log_token_prob = token_log_probs[tokens[0, current_tok]].item()
    log_sum += log_token_prob

    token = tokenizer.decode(tokens[:, current_tok])
    print(f"Token: {token}, Log Prob: {log_token_prob}")

print(f"Total Log Sum Probability: {log_sum}")

Method 2:

name = "Yukang/LongAlpaca-7B"
tokenizer = AutoTokenizer.from_pretrained(name, device_map="auto", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(name, device_map="auto", trust_remote_code=True)
inputs = "I like"
outputs = "ice cream"

input_tokens = tokenizer.encode(inputs, add_special_tokens=False, return_tensors="pt")
output_tokens = tokenizer.encode(outputs, add_special_tokens=False, return_tensors="pt")
input_tokens_updated = input_tokens.clone()
log_sum = 0

for i in range(output_tokens.shape[1]):
    # Predict with the given model
    with torch.no_grad():
        model_outputs = model(input_tokens_updated)  # avoid shadowing the `outputs` string above
        logit_predictions = model_outputs.logits

    # Extract the log probability of the most recently added token
    last_token_logit = logit_predictions[0, -1, :]
    last_token_log_probs = torch.nn.functional.log_softmax(last_token_logit, dim=-1)
    log_token_prob = last_token_log_probs[output_tokens[0, i]].item()
    log_sum += log_token_prob

    # Incrementally add an output token to the current sequence
    last_token = tokenizer.decode(output_tokens[:, i])
    input_tokens_updated = torch.cat([input_tokens_updated, output_tokens[:, i:i+1]], dim=1)
    print([tokenizer.decode(token) for token in input_tokens_updated[0]])  # sequence so far, token by token
    print(f"Token: {last_token}, Log Prob: {log_token_prob}")
  print(f"Total Log Sum Probability: {log_sum}")

Thank you in advance!

Were you able to figure this out? I’m experiencing very similar behaviour on my end and I’m not sure why Method 1 and Method 2 differ so much. Thanks a lot in advance!