I am trying to obtain the likelihood of a given sentence under a model.
TL;DR: The logits don't match between generating a sentence and evaluating it afterwards. Writing sentence = prompt + response, in pseudo-code:
model.generate(prompt).logits != model(sentence, labels=sentence).logits[prompt_len:]
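For context, once the two sets of logits agree, I plan to turn them into a sentence log-likelihood roughly like this (a minimal sketch; the function name and argument names are my own, and response_ids stands for the generated token ids):

import torch

def response_log_likelihood(logits: torch.Tensor, response_ids: torch.Tensor) -> torch.Tensor:
    # logits: batch x response_tokens x vocab, one row per generated token
    # response_ids: batch x response_tokens, the token actually chosen at each step
    log_probs = torch.log_softmax(logits, dim=-1)  # normalize over the vocabulary
    token_log_probs = log_probs.gather(-1, response_ids.unsqueeze(-1)).squeeze(-1)
    return token_log_probs.sum(dim=-1)  # sum of per-token log-probs = sentence log-likelihood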
Here is my example written out in full:
import torch
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, GenerationConfig
torch.set_printoptions(linewidth=200)
def main():
    ###### SET UP ######
    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=model_dir)
    tokenizer.pad_token = tokenizer.eos_token
    # model
    config = AutoConfig.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, config=config, cache_dir=model_dir)

    # we will have a prompt and a response; sentence = prompt + response
    prompt_text = "Until 1988, Germany was "
    prompt_ids = tokenizer.encode(prompt_text, return_tensors="pt")
    prompt_length = prompt_ids.shape[1]

    generation_config = GenerationConfig(
        max_length=20,
        temperature=TEMPERATURE,
        renormalize_logits=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        output_logits=True,
    )
    ###### LOGITS DURING GENERATION ######
    with torch.inference_mode():
        # generate the output sequence
        outputs = model.generate(
            inputs=prompt_ids,
            generation_config=generation_config,
        )
    # get logits: batch x generated_tokens x vocab
    logits = torch.stack(outputs.logits, dim=1)
    # the full sentence (prompt + response) as ids
    sentence_ids = outputs.sequences

    ###### LOGITS FOR AN EXISTING SENTENCE ######
    # forward pass over the full sentence to get logits
    with torch.inference_mode():
        outputs = model(sentence_ids, labels=sentence_ids)
    # get logits: batch x response_tokens x vocab
    logits_post_hoc = outputs.logits[:, prompt_length:, :]
print(f"Shape of Logits from Generation: {logits.shape} Evaluation: {logits_post_hoc.shape}")
###### COMPARE LOGITS ######
logits_are_close = torch.isclose(logits, logits_post_hoc, atol=1e-6, rtol=1e-6).all().item()
print(f"Logits are close: {logits_are_close}")
print("#### LOGITS GENERATION ####")
n_tokens = 3
n_logits = 10
print(logits[:, :n_tokens, :n_logits])
print("#### LOGITS EVALUATION ####")
print(logits_post_hoc[:, :n_tokens, :n_logits])
print("#### MAX PER TOKEN ####")
print(f"\t Generation: {logits.max(-1).values}")
print(f"\t Evaluation: {logits_post_hoc.max(-1).values}")
l2_diff = torch.norm(logits - logits_post_hoc, p=2)
l_inf_diff = torch.norm(logits - logits_post_hoc, p=float("inf"))
print(f"L2 Norm: {l2_diff.item()} L_inf Norm: {l_inf_diff.item()}")
print("VERY DIFFERENT!!!")
if __name__ == "__main__":
    MODEL_NAME = "openai-community/gpt2"
    TEMPERATURE = 1.0
    model_dir = "/path/to/models"
    main()
which generates this output:
Shape of Logits from Generation: torch.Size([1, 14, 50257]) Evaluation: torch.Size([1, 14, 50257])
Logits are close: False
#### LOGITS GENERATION ####
tensor([[[-103.2997, -104.4795, -106.1088, -105.8631, -106.4693, -109.1048, -107.7903, -103.7282, -105.0102, -105.2927],
[ -96.2521, -90.0448, -91.3474, -92.7398, -94.4648, -92.7509, -90.8869, -89.4783, -95.3292, -93.2915],
[-117.1030, -113.4219, -113.2072, -117.1805, -115.8642, -117.2663, -115.1368, -115.4635, -113.1924, -115.8379]]])
#### LOGITS EVALUATION ####
tensor([[[ -96.2520, -90.0447, -91.3473, -92.7398, -94.4648, -92.7508, -90.8869, -89.4783, -95.3291, -93.2914],
[-117.1030, -113.4219, -113.2072, -117.1805, -115.8642, -117.2664, -115.1368, -115.4635, -113.1924, -115.8379],
[-110.0100, -107.3938, -112.7266, -114.4252, -110.3775, -111.1679, -108.8828, -108.9165, -108.1152, -108.1860]]])
#### MAX PER TOKEN ####
Generation: tensor([[ -90.5488, -86.6507, -102.9849, -96.2440, -77.3176, -13.5731, -75.8808, -64.5627, -111.4290, -92.7544, -114.4539, -101.5360, -95.7295, -117.2953]])
Evaluation: tensor([[ -86.6507, -102.9849, -96.2440, -77.3176, -13.5732, -75.8808, -64.5626, -111.4289, -92.7544, -114.4540, -101.5359, -95.7294, -117.2953, -91.2336]])
L2 Norm: 25701.576171875 L_inf Norm: 72.83152770996094
VERY DIFFERENT!!!
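One thing I notice in the MAX PER TOKEN rows: the evaluation values look like the generation values shifted by one position. A quick check I could append at the end of main() (my own addition, not in the run above, reusing logits and logits_post_hoc) would test exactly that:

    # hypothetical check: are the two sets of logits simply offset by one token position?
    shifted_close = torch.isclose(
        logits[:, 1:, :],            # generation logits from the 2nd generated token on
        logits_post_hoc[:, :-1, :],  # evaluation logits up to the 2nd-to-last token
        atol=1e-4,
        rtol=1e-4,
    ).all().item()
    print(f"Logits are close after shifting by one: {shifted_close}")

Is an off-by-one in my slicing (prompt_length vs. prompt_length - 1) the explanation, or is something else going on?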