Hey @vblagoje, I believe you forgot one tiny detail; other than that, it looks like a solid implementation! In a nutshell, these models return, at each position, the probabilities for the *next* token, which means that `logits[batch_idx, seq_idx, vocab_idx]` actually contains the logits corresponding to `input_ids[batch_idx, seq_idx + 1]`. This implies that the pairing between the probabilities and the tokens has to be shifted by one.
Here is a modified script:
```python
from pprint import pprint

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def to_tokens_and_logprobs(model, tokenizer, input_texts):
    input_ids = tokenizer(input_texts, padding=True, return_tensors="pt").input_ids
    outputs = model(input_ids)
    probs = torch.log_softmax(outputs.logits, dim=-1).detach()

    # collect the probability of the generated token -- probability at index 0 corresponds to the token at index 1
    probs = probs[:, :-1, :]
    input_ids = input_ids[:, 1:]
    gen_probs = torch.gather(probs, 2, input_ids[:, :, None]).squeeze(-1)

    batch = []
    for input_sentence, input_probs in zip(input_ids, gen_probs):
        text_sequence = []
        for token, p in zip(input_sentence, input_probs):
            if token not in tokenizer.all_special_ids:
                text_sequence.append((tokenizer.decode(token), p.item()))
        batch.append(text_sequence)
    return batch


tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.config.pad_token_id = model.config.eos_token_id

input_texts = ["One plus one is two", "Good morning", "Hello, how are you?"]

batch = to_tokens_and_logprobs(model, tokenizer, input_texts)
pprint(batch)
```
which yields
```
[[('One', -5.882715702056885),
  (' plus', -9.785109519958496),
  (' one', -0.7229145169258118),
  (' is', -2.494063377380371),
  (' two', -6.137458324432373)],
 [('Good', -7.5790300369262695), (' morning', -1.826707124710083)],
 [(',', -2.343151807785034),
  (' how', -4.339702606201172),
  (' are', -2.6824729442596436),
  (' you', -0.4109247326850891),
  ('?', -1.8950778245925903)]]
```
Notice how high the log probabilities for certain obvious tokens are, like ` morning`, ` you`, or `?`! Checking these tokens is always a good sanity check.
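For instance, here is a quick sanity-check sketch (it reuses the `model` and `tokenizer` already loaded above, so run the script first) that lists the model's top candidates after `Good`; the token ` morning` should sit near the top, with a log probability similar to the one printed above:

```python
import torch

# Distribution over the *next* token after the prompt "Good"
good_ids = tokenizer("Good", return_tensors="pt").input_ids
with torch.no_grad():
    logits = model(good_ids).logits
next_token_log_probs = torch.log_softmax(logits[0, -1, :], dim=-1)

# Top 5 candidate continuations, most likely first
top = torch.topk(next_token_log_probs, k=5)
for log_p, token_id in zip(top.values, top.indices):
    print(repr(tokenizer.decode(token_id)), log_p.item())
```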
Also, look at the last sentence: there is no log probability for its first token. If you want a value for that token, you need to add extra padding on the left, so that the first text token is not the actual first token fed to the model.
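Here is a minimal sketch of that last point, assuming you are fine with using GPT-2's BOS token (`<|endoftext|>`) as the extra token on the left; the helper name `first_token_logprob` is just for illustration, and it again reuses the `model` and `tokenizer` from the script above:

```python
import torch

def first_token_logprob(model, tokenizer, text):
    """Log probability of the first real token, obtained by prepending the BOS token."""
    # Prepend the BOS token so the first text token is predicted rather than merely given
    input_ids = tokenizer(tokenizer.bos_token + text, return_tensors="pt").input_ids
    with torch.no_grad():
        logits = model(input_ids).logits
    log_probs = torch.log_softmax(logits, dim=-1)
    # Position 0 of the logits predicts the token at position 1 (the first text token)
    first_text_token = input_ids[0, 1]
    return tokenizer.decode(first_text_token), log_probs[0, 0, first_text_token].item()

print(first_token_logprob(model, tokenizer, "Hello, how are you?"))
```

The same trick should carry over to the batched version: prepend `tokenizer.bos_token` to every input text before tokenizing, and the loop above will then also emit a (token, log probability) pair for each sentence's first word.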