LLaMA2 - tokenizer padding affecting logits (even with attention_mask)

I’ve been working with the LLaMA2 model recently and noticed some behavior I’m confused by, probably due to a misunderstanding on my part.

Specifically, when I pad an input, I get different loss values and logits than I do for the unpadded input, even when I pass in the appropriate attention_mask.

Example:

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = LlamaTokenizer.from_pretrained(model_id, truncation_side='left', padding_side='right')
tokenizer.pad_token = tokenizer.eos_token
model = LlamaForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map={"":0})

input_prompt = "I've got a lovely bunch of coconuts do do do dooo"
input_prompt_tokenized = tokenizer(input_prompt, return_tensors="pt").to('cuda')
print(input_prompt_tokenized)
>>> {'input_ids': tensor([[    1,   306, 29915,   345,  2355,   263, 12355,   873, 14928,   310,
          1302,   535,  8842,   437,   437,   437,   437,  3634]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0')}

input_prompt_padded_tokenized = tokenizer(input_prompt, return_tensors="pt", padding="max_length", max_length=50).to('cuda')
print(input_prompt_padded_tokenized)
>>> {'input_ids': tensor([[    1,   306, 29915,   345,  2355,   263, 12355,   873, 14928,   310,
          1302,   535,  8842,   437,   437,   437,   437,  3634,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0]], device='cuda:0')}

# Mask the labels so that positions holding the padding token do not contribute to the loss
# (-100 is the ignore_index used by the model's internal cross-entropy loss)
input_ids_masked = torch.where(
    input_prompt_padded_tokenized.input_ids == tokenizer.pad_token_id,
    torch.tensor(-100, dtype=torch.int64, device='cuda'),
    input_prompt_padded_tokenized.input_ids,
)
print(input_ids_masked)
>>> tensor([[    1,   306, 29915,   345,  2355,   263, 12355,   873, 14928,   310,
          1302,   535,  8842,   437,   437,   437,   437,  3634,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100]],
       device='cuda:0')

# Calculate logits and loss from the model using unpadded input
loss1 = model(
    input_prompt_tokenized.input_ids,
    attention_mask=input_prompt_tokenized.attention_mask,
    labels=input_prompt_tokenized.input_ids
).loss
print(f"Loss (no padding): {loss1}")
>>> Loss (no padding): 2.885075330734253

logits1 = model(
    input_prompt_tokenized.input_ids,
    attention_mask=input_prompt_tokenized.attention_mask,
    labels=input_prompt_tokenized.input_ids
).logits
print(f"Logits (no padding): {logits1}")
>>> Logits (no padding): tensor([[[-0.1272, -0.3352,  0.4062,  ...,  1.1426,  1.7002,  0.5195],
         [-8.2344, -9.7266, -0.2708,  ..., -3.2207, -7.9453, -2.6602],
         [-3.4395, -1.9326,  3.8789,  ..., -1.1035, -2.5391,  0.0756],
         ...,
         [-3.2969, -2.8438,  8.3672,  ...,  0.7686, -2.7871, -2.1738],
         [-2.7637, -2.6953,  8.8516,  ...,  0.6499, -2.8164, -2.2754],
         [-2.4648, -1.4766,  8.0938,  ..., -0.4836, -2.8984, -2.3613]]],
       device='cuda:0', grad_fn=<ToCopyBackward0>)

# Calculate the logits and loss from the model using padded input (should be the same)
loss2 = model(
    input_prompt_padded_tokenized.input_ids,
    attention_mask=input_prompt_padded_tokenized.attention_mask,
    labels=input_ids_masked
).loss
print(f"Loss (padding input & masking labels): {loss2}")
>>> Loss (padding input & masking labels): 2.8734283447265625

logits2 = model(
    input_prompt_padded_tokenized.input_ids,
    attention_mask=input_prompt_padded_tokenized.attention_mask,
    labels=input_ids_masked
).logits
print(f"Logits (padding input & masking labels): {logits2}")
>>> Logits (padding input & masking labels): tensor([[[-1.2396e-01, -3.0322e-01,  3.9062e-01,  ...,  1.1309e+00,
           1.6719e+00,  4.6948e-01],
         [-8.2578e+00, -1.0047e+01, -3.9233e-01,  ..., -3.2793e+00,
          -7.9570e+00, -2.7129e+00],
         [-3.5547e+00, -1.9805e+00,  3.7363e+00,  ..., -1.2891e+00,
          -2.7754e+00,  5.6000e-03],
         ...,
         [-4.9844e+00,  9.5781e+00,  5.0625e+00,  ..., -2.4512e+00,
          -3.1426e+00, -2.8789e+00],
         [-5.7969e+00,  1.5977e+00,  4.9141e+00,  ..., -3.9961e+00,
          -2.0273e+00, -4.2500e+00],
         [-5.6680e+00,  2.0645e+00,  4.7344e+00,  ..., -3.2168e+00,
          -1.3691e+00, -4.5547e+00]]], device='cuda:0',
       grad_fn=<ToCopyBackward0>)
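
As a side check on the label masking (not part of the run above, just a sketch): LlamaForCausalLM computes its loss by shifting the logits and labels by one position and applying cross-entropy with ignore_index=-100, so recomputing it by hand from logits2 and input_ids_masked should reproduce loss2. That would confirm the -100 masking is doing its job and the discrepancy really is in the logits themselves.

import torch.nn.functional as F

# Recompute the padded-run loss by hand, the same way the model does internally:
# shift by one position and ignore the -100 label positions.
shift_logits = logits2[:, :-1, :]
shift_labels = input_ids_masked[:, 1:]
manual_loss = F.cross_entropy(
    shift_logits.reshape(-1, shift_logits.size(-1)),
    shift_labels.reshape(-1),
    ignore_index=-100,  # the default, written out for clarity
)
print(manual_loss)  # should match loss2 up to rounding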

I expected both the losses and the logits to be the same (for the logits, I only expect the entries corresponding to the non-padded tokens to agree). I thought that the attention_mask prevented the padded tokens from mattering, and since the padding goes on the right, the original tokens can’t attend to the padding anyway!

Or so I thought. The difference isn’t large, but I expected the results to be identical. What am I missing here?
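
For concreteness, this is the kind of comparison I have in mind for the logits (a sketch reusing the tensors from above; the unpadded sequence is 18 tokens long):

# Keep only the 18 real (non-padded) positions of the padded run and compare
# them to the unpadded run. Padding is on the right, so the real tokens come first.
seq_len = input_prompt_tokenized.input_ids.shape[1]  # 18
logits2_trimmed = logits2[:, :seq_len, :]
print((logits1 - logits2_trimmed).abs().max())
print(torch.allclose(logits1, logits2_trimmed, atol=1e-3))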

Could it be because of the randomness of the generative model?

A forward pass through the model should be deterministic, as far as I understand: the same input sequence should give the same logits, which are then used to get next-token prediction probabilities. The randomness I’m aware of in most LLMs comes from how you decide to pick the next token (e.g., greedy decoding, top-k sampling, beam search), not from the forward pass through the network.
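
A quick sanity check along those lines (a sketch, reusing the model and unpadded inputs from the original post) is to run the exact same forward pass twice and compare the raw logits:

# If the forward pass itself were random, these two would already disagree.
with torch.no_grad():
    out_a = model(input_prompt_tokenized.input_ids,
                  attention_mask=input_prompt_tokenized.attention_mask).logits
    out_b = model(input_prompt_tokenized.input_ids,
                  attention_mask=input_prompt_tokenized.attention_mask).logits
print(torch.equal(out_a, out_b))  # expect True: no sampling happens in a forward pass

If those agree exactly, the discrepancy comes from the padding path, not from any stochasticity in the model.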

I am seeing the same issue. I have tried padding on both the left and the right; neither works. Did you figure out the root cause?
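
In case it helps, this is roughly what I tried (a sketch; tokenizer, model, input_prompt, and the unpadded logits1 are the same as in the original post, and only padding_side changes):

for side in ("right", "left"):
    tokenizer.padding_side = side
    padded = tokenizer(input_prompt, return_tensors="pt",
                       padding="max_length", max_length=50).to('cuda')
    with torch.no_grad():
        logits = model(padded.input_ids, attention_mask=padded.attention_mask).logits
    # The 18 real tokens sit at the start for right padding, at the end for left padding
    real = logits[:, :18, :] if side == "right" else logits[:, -18:, :]
    print(side, (logits1 - real).abs().max().item())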