[Announcement] Generation: Get probabilities for generated output

Hey @vblagoje :wave: I believe you forgot one tiny detail; other than that, it looks like a solid implementation! In a nutshell, these models return the scores for the next token, which means that logits[batch_idx, seq_idx, vocab_idx] actually holds the logit for the token at input_ids[batch_idx, seq_idx + 1], not the one at seq_idx. This implies that the pairing between the probabilities and the tokens has to be shifted by one :slight_smile:
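
To make the shift concrete, here is a toy sketch with hand-written logits instead of a real model. The tensor values and the 4-token vocabulary are made up purely for illustration; only the indexing matters.

```python
import torch

# Toy stand-in for model outputs: batch of 1, 3 positions, vocabulary of 4.
# These numbers are invented; they do not come from any real model.
logits = torch.tensor([[[2.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 3.0, 0.0],
                        [0.0, 1.0, 0.0, 0.0]]])
input_ids = torch.tensor([[1, 0, 2]])

logprobs = torch.log_softmax(logits, dim=-1)

# Position t predicts token t + 1, so drop the last prediction and the first token.
shifted_logprobs = logprobs[:, :-1, :]  # predictions made at positions 0 and 1
target_ids = input_ids[:, 1:]           # tokens found at positions 1 and 2

# Pick, for every position, the log-probability assigned to the token that actually follows.
token_logprobs = torch.gather(shifted_logprobs, 2, target_ids[:, :, None]).squeeze(-1)
print(token_logprobs.shape)  # torch.Size([1, 2])
```

Each sequence ends up with one value fewer than it has tokens, since nothing predicts the very first token.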

Here is a modified script:

from pprint import pprint
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def to_tokens_and_logprobs(model, tokenizer, input_texts):
    input_ids = tokenizer(input_texts, padding=True, return_tensors="pt").input_ids
    outputs = model(input_ids)
    probs = torch.log_softmax(outputs.logits, dim=-1).detach()

    # collect the log-probability of each token -- the prediction at index 0 corresponds to the token at index 1
    probs = probs[:, :-1, :]
    input_ids = input_ids[:, 1:]
    gen_probs = torch.gather(probs, 2, input_ids[:, :, None]).squeeze(-1)

    batch = []
    for input_sentence, input_probs in zip(input_ids, gen_probs):
        text_sequence = []
        for token, p in zip(input_sentence, input_probs):
            if token not in tokenizer.all_special_ids:
                text_sequence.append((tokenizer.decode(token), p.item()))
        batch.append(text_sequence)
    return batch


tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.config.pad_token_id = model.config.eos_token_id

input_texts = ["One plus one is two", "Good morning", "Hello, how are you?"]

batch = to_tokens_and_logprobs(model, tokenizer, input_texts)
pprint(batch)

which yields

[[('One', -5.882715702056885),
  (' plus', -9.785109519958496),
  (' one', -0.7229145169258118),
  (' is', -2.494063377380371),
  (' two', -6.137458324432373)],
 [('Good', -7.5790300369262695), (' morning', -1.826707124710083)],
 [(',', -2.343151807785034),
  (' how', -4.339702606201172),
  (' are', -2.6824729442596436),
  (' you', -0.4109247326850891),
  ('?', -1.8950778245925903)]]

Notice how high the log-probabilities for certain obvious tokens are, like morning, you, or ?! Checking these tokens is always a good sanity check :smiley: Also, look at the last sentence: there is no log-probability for its first token (Hello), because it is the first token fed to the model and nothing precedes it to predict it from. If you want the log-probability for that token, you need to add extra padding on the left, so that the first text token is not the actual first token fed to the model.
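
One way to add that extra left padding is to prepend a single column of pad tokens to input_ids before calling the model. The sketch below is only an illustration: prepend_pad is a hypothetical helper (not part of transformers), and the token ids in the example batch are made up, apart from 50256, which is GPT-2's <|endoftext|> id and the pad token used in the script above.

```python
import torch


def prepend_pad(input_ids, pad_token_id):
    """Return input_ids with one extra pad column on the left (hypothetical helper)."""
    batch_size = input_ids.shape[0]
    # new_full keeps the dtype and device of input_ids
    extra_column = input_ids.new_full((batch_size, 1), pad_token_id)
    return torch.cat([extra_column, input_ids], dim=1)


PAD_ID = 50256  # GPT-2's <|endoftext|> id, reused as the pad token

# Made-up, already left-padded batch: the second row has no padding of its own.
example_ids = torch.tensor([[PAD_ID, PAD_ID, 11, 12],
                            [21, 22, 23, 24]])

padded_ids = prepend_pad(example_ids, PAD_ID)
print(padded_ids.shape)  # torch.Size([2, 5])
```

In the script above, this would be applied to input_ids right after tokenization, e.g. input_ids = prepend_pad(input_ids, tokenizer.pad_token_id). After the shift by one, input_ids[:, 1:] then covers every original token, including the first token of the longest sentence.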
