Logits from generate and model call differ

I am trying to obtain the likelihood of a given sentence under a model.

TL;DR: The logits don’t match when generating a sentence and when evaluating it. With sentence = prompt + response (in pseudocode):

 model.generate(prompt).logits != model(sentence, labels=sentence).logits[prompt_len:]
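For reference, this is how I plan to turn the logits into a sentence likelihood once they line up (a minimal sketch; it assumes the usual causal-LM convention that the logits at position t score the token at position t + 1):

import torch
import torch.nn.functional as F

def sequence_log_likelihood(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    """Sum the log-probabilities of token_ids under logits.

    logits:    batch x tokens x vocab, where logits[:, t] scores token t + 1
    token_ids: batch x (tokens + 1), the sequence including the first token
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # shift by one: each logit position predicts the next token
    targets = token_ids[:, 1:].unsqueeze(-1)
    token_log_probs = log_probs.gather(-1, targets).squeeze(-1)
    return token_log_probs.sum(dim=-1)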

Here is my example written out in full:

import torch
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, GenerationConfig

torch.set_printoptions(linewidth=200)


def main():
    ###### SET UP ######
    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=model_dir)
    tokenizer.pad_token = tokenizer.eos_token

    # model
    config = AutoConfig.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, config=config, cache_dir=model_dir)

    # we will have prompt and response. sentence = prompt + response
    prompt_text = "Until 1988, Germany was "
    prompt_ids = tokenizer.encode(prompt_text, return_tensors="pt")
    prompt_length = prompt_ids.shape[1]

    generation_config = GenerationConfig(
        max_length=20,
        temperature=TEMPERATURE,
        renormalize_logits=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        output_logits=True,
    )
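    # NOTE: with output_logits=True, generate() returns the raw per-step
    # logits; output_scores returns the processed scores (after temperature
    # scaling, renormalization, etc.). The comparison below uses raw logits.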

    ###### LOGITS DURING GENERATION ######
    with torch.inference_mode():
        # Generate output sequence
        outputs = model.generate(
            inputs=prompt_ids,
            generation_config=generation_config,
        )

    # outputs.logits is a tuple with one (batch, vocab) tensor per generated
    # token; stacking on dim=1 gives batch x tokens x vocab
    logits = torch.stack(outputs.logits, dim=1)

    # get sentence and response as ids and text
    sentence_ids = outputs.sequences

    ###### LOGITS FOR THE EXISTING SENTENCE ######
    # forward pass over the full sequence; labels only add a loss to the
    # output and do not change the logits
    with torch.inference_mode():
        outputs = model(sentence_ids, labels=sentence_ids)

    # keep only the logits for the response tokens: batch x tokens x vocab
    logits_post_hoc = outputs.logits[:, prompt_length:, :]
    print(f"Shape of Logits from Generation: {logits.shape} Evaluation: {logits_post_hoc.shape}")

    ###### COMPARE LOGITS ######

    logits_are_close = torch.isclose(logits, logits_post_hoc, atol=1e-6, rtol=1e-6).all().item()
    print(f"Logits are close: {logits_are_close}")

    print("#### LOGITS GENERATION ####")
    n_tokens = 3
    n_logits = 10
    print(logits[:, :n_tokens, :n_logits])
    print("#### LOGITS EVALUATION ####")
    print(logits_post_hoc[:, :n_tokens, :n_logits])

    print("#### MAX PER TOKEN ####")
    print(f"\t Generation: {logits.max(-1).values}")
    print(f"\t Evaluation: {logits_post_hoc.max(-1).values}")

    l2_diff = torch.norm(logits - logits_post_hoc, p=2)
    l_inf_diff = torch.norm(logits - logits_post_hoc, p=float("inf"))
    print(f"L2 Norm: {l2_diff.item()} L_inf Norm: {l_inf_diff.item()}")

    print("VERY DIFFERENT!!!")


if __name__ == "__main__":
    MODEL_NAME = "openai-community/gpt2"
    TEMPERATURE = 1.0

    model_dir = "/path/to/models"

    main()

which generates this output:

Shape of Logits from Generation: torch.Size([1, 14, 50257]) Evaluation: torch.Size([1, 14, 50257])
Logits are close: False
#### LOGITS GENERATION ####
tensor([[[-103.2997, -104.4795, -106.1088, -105.8631, -106.4693, -109.1048, -107.7903, -103.7282, -105.0102, -105.2927],
         [ -96.2521,  -90.0448,  -91.3474,  -92.7398,  -94.4648,  -92.7509,  -90.8869,  -89.4783,  -95.3292,  -93.2915],
         [-117.1030, -113.4219, -113.2072, -117.1805, -115.8642, -117.2663, -115.1368, -115.4635, -113.1924, -115.8379]]])
#### LOGITS EVALUATION ####
tensor([[[ -96.2520,  -90.0447,  -91.3473,  -92.7398,  -94.4648,  -92.7508,  -90.8869,  -89.4783,  -95.3291,  -93.2914],
         [-117.1030, -113.4219, -113.2072, -117.1805, -115.8642, -117.2664, -115.1368, -115.4635, -113.1924, -115.8379],
         [-110.0100, -107.3938, -112.7266, -114.4252, -110.3775, -111.1679, -108.8828, -108.9165, -108.1152, -108.1860]]])
#### MAX PER TOKEN ####
         Generation: tensor([[ -90.5488,  -86.6507, -102.9849,  -96.2440,  -77.3176,  -13.5731,  -75.8808,  -64.5627, -111.4290,  -92.7544, -114.4539, -101.5360,  -95.7295, -117.2953]])
         Evaluation: tensor([[ -86.6507, -102.9849,  -96.2440,  -77.3176,  -13.5732,  -75.8808,  -64.5626, -111.4289,  -92.7544, -114.4540, -101.5359,  -95.7294, -117.2953,  -91.2336]])
L2 Norm: 25701.576171875 L_inf Norm: 72.83152770996094
VERY DIFFERENT!!!
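Interestingly, in the MAX PER TOKEN rows the evaluation values look like the generation values shifted by one position. This sketch (reusing logits and logits_post_hoc from the script above, with a looser tolerance since the printed values only agree to about four decimals) checks that:

shifted_close = torch.isclose(
    logits[:, 1:, :], logits_post_hoc[:, :-1, :], atol=1e-3, rtol=1e-4
).all().item()
print(f"Shifted logits are close: {shifted_close}")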