I've been working with the Llama 2 model recently and noticed some behavior I'm confused about, probably due to a misunderstanding on my part.
Specifically, when I pad an input I get different results for the loss and logits than with the unpadded input, even when I pass in the appropriate attention_mask.
Example:
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = LlamaTokenizer.from_pretrained(model_id, truncation_side='left', padding_side='right')
tokenizer.pad_token = tokenizer.eos_token
model = LlamaForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map={"":0})
input_prompt = "I've got a lovely bunch of coconuts do do do dooo"
input_prompt_tokenized = tokenizer(input_prompt, return_tensors="pt").to('cuda')
print(input_prompt_tokenized)
>>> {'input_ids': tensor([[ 1, 306, 29915, 345, 2355, 263, 12355, 873, 14928, 310,
1302, 535, 8842, 437, 437, 437, 437, 3634]],
device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
device='cuda:0')}
input_prompt_padded_tokenized = tokenizer(input_prompt, return_tensors="pt", padding="max_length", max_length=50).to('cuda')
print(input_prompt_padded_tokenized)
>>> {'input_ids': tensor([[ 1, 306, 29915, 345, 2355, 263, 12355, 873, 14928, 310,
1302, 535, 8842, 437, 437, 437, 437, 3634, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2]],
device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0]], device='cuda:0')}
# Mask labels so that the values corresponding to the padding token do not contribute to the loss
input_ids_masked = torch.zeros(input_prompt_padded_tokenized.input_ids.shape, dtype=torch.int64).to('cuda')
torch.where(
    input_prompt_padded_tokenized.input_ids == tokenizer.pad_token_id,
    torch.tensor(-100, dtype=torch.int64),
    input_prompt_padded_tokenized.input_ids,
    out=input_ids_masked,
)
print(input_ids_masked)
>>> tensor([[ 1, 306, 29915, 345, 2355, 263, 12355, 873, 14928, 310,
1302, 535, 8842, 437, 437, 437, 437, 3634, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100]],
device='cuda:0')
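An equivalent way to build these labels, in case it reads more cleanly, is to mask on the attention_mask instead of the pad token id. Just a sketch (labels_alt is an illustrative name of mine); it gives the same result here because the prompt itself contains no EOS token:
# Mark every padding position (attention_mask == 0) with -100
labels_alt = input_prompt_padded_tokenized.input_ids.masked_fill(
    input_prompt_padded_tokenized.attention_mask == 0, -100
)
assert torch.equal(labels_alt, input_ids_masked)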
# Calculate logits and loss from the model using unpadded input
loss1 = model(
input_prompt_tokenized.input_ids,
attention_mask=input_prompt_tokenized.attention_mask,
labels=input_prompt_tokenized.input_ids
).loss
print(f"Loss (no padding): {loss1}")
>>> Loss (no padding): 2.885075330734253
logits1 = model(
input_prompt_tokenized.input_ids,
attention_mask=input_prompt_tokenized.attention_mask,
labels=input_prompt_tokenized.input_ids
).logits
print(f"Logits (no padding): {logits1}")
>>> Logits (no padding): tensor([[[-0.1272, -0.3352, 0.4062, ..., 1.1426, 1.7002, 0.5195],
[-8.2344, -9.7266, -0.2708, ..., -3.2207, -7.9453, -2.6602],
[-3.4395, -1.9326, 3.8789, ..., -1.1035, -2.5391, 0.0756],
...,
[-3.2969, -2.8438, 8.3672, ..., 0.7686, -2.7871, -2.1738],
[-2.7637, -2.6953, 8.8516, ..., 0.6499, -2.8164, -2.2754],
[-2.4648, -1.4766, 8.0938, ..., -0.4836, -2.8984, -2.3613]]],
device='cuda:0', grad_fn=<ToCopyBackward0>)
# Calculate the logits and loss from the model using padded input (should be the same)
loss2 = model(
input_prompt_padded_tokenized.input_ids,
attention_mask=input_prompt_padded_tokenized.attention_mask,
labels=input_ids_masked
).loss
print(f"Loss (padding input & masking labels): {loss2}")
>>> Loss (padding input & masking labels): 2.8734283447265625
logits2 = model(
input_prompt_padded_tokenized.input_ids,
attention_mask=input_prompt_padded_tokenized.attention_mask,
labels=input_ids_masked
).logits
print(f"Logits (padding input & masking labels): {logits2}")
>>> Logits (padding input & masking labels): tensor([[[-1.2396e-01, -3.0322e-01, 3.9062e-01, ..., 1.1309e+00,
1.6719e+00, 4.6948e-01],
[-8.2578e+00, -1.0047e+01, -3.9233e-01, ..., -3.2793e+00,
-7.9570e+00, -2.7129e+00],
[-3.5547e+00, -1.9805e+00, 3.7363e+00, ..., -1.2891e+00,
-2.7754e+00, 5.6000e-03],
...,
[-4.9844e+00, 9.5781e+00, 5.0625e+00, ..., -2.4512e+00,
-3.1426e+00, -2.8789e+00],
[-5.7969e+00, 1.5977e+00, 4.9141e+00, ..., -3.9961e+00,
-2.0273e+00, -4.2500e+00],
[-5.6680e+00, 2.0645e+00, 4.7344e+00, ..., -3.2168e+00,
-1.3691e+00, -4.5547e+00]]], device='cuda:0',
grad_fn=<ToCopyBackward0>)
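For reference, this is my mental model of how the loss is derived from the logits: a minimal sketch of the shifted cross-entropy with ignore_index=-100 that I believe LlamaForCausalLM applies internally (not the actual library code, and shifted_ce_loss is just my own helper name):
import torch.nn.functional as F
def shifted_ce_loss(logits, labels):
    # Predict token t+1 from the logits at position t; labels of -100 are ignored
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)).float(),  # upcast only for a stable standalone check
        shift_labels.view(-1),
        ignore_index=-100,
    )
print(shifted_ce_loss(logits1, input_prompt_tokenized.input_ids))  # should roughly reproduce loss1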
I expected both the losses and the logits to be the same (for the logits, I just expect the entries corresponding to the non-padded tokens to agree). I thought that the attention_mask prevented the padded tokens from mattering, and since the padding goes on the right, the original tokens can't attend to the padding anyway!
Or so I thought. It's not a large difference, but I expected the results to be identical. What am I missing here?
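In case it helps, this is roughly how I'd quantify the difference, restricted to the positions of the real (non-padding) tokens (just a sketch using the variables from above):
# Compare the losses and the logits at the 18 real token positions only
n_real = input_prompt_tokenized.input_ids.shape[1]
max_diff = (logits1 - logits2[:, :n_real, :]).abs().max()
print(f"Max abs logit difference over real tokens: {max_diff.item():.4f}")
print(f"Losses: {loss1.item():.6f} vs {loss2.item():.6f}")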