I’m running some experiments on the probability of a model choosing a particular answer, and I noticed that, even with greedy decoding, the logits returned by `model.generate(input_ids)` are very slightly different from the ones returned by `model(cat([input_ids, answer]))` for the same input. Using Llama-3.1-8B-Instruct, I get the following values.
```
input_ids = tensor([[128000, 16533, 279, 2768, 3488, 304, 264, 2478, 4339, 323, 449, 912, 37666, 13, 3639, 574, 279, 8250, 315, 578, 1050, 359, 2461, 315, 279, 21080, 555, 89315, 92898, 30, 578, 1050, 359, 2461, 315, 279, 21080, 555, 89315, 92898, 36513, 369]], device='cuda:0')
attention_mask = tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')
gen_target_logits = tensor([[18.1250, 20.8750, 20.8750, 16.1250, 14.2500, 15.6875, 23.8750, 22.7500, 20.2500, 22.1250, 20.7500, 19.6250, 21.7500, 29.5000, 18.5000, 19.8750, 20.5000, 23.6250, 20.6250, 22.2500]], device='cuda:0')
call_target_logits = tensor([[18.2500, 20.8750, 20.8750, 16.2500, 12.1875, 12.4375, 12.7500, 12.9375, 12.8125, 13.0625, 13.1250, 12.8750, 12.8125, 12.4375, 12.1875, 12.1250, 11.9375, 12.3125, 12.6250, 13.0000]], device='cuda:0')
```
As you can see, `gen_target_logits` and `call_target_logits` are almost equal: each element is either identical or slightly lower in the `generate` version.
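To put a number on "almost equal", a quick comparison along these lines works (a sketch; it assumes the two tensors printed above are in scope):

```python
import torch

# Quantify the mismatch between the two logit tensors printed above.
diff = (gen_target_logits - call_target_logits).abs()
print(diff.max().item())   # largest absolute gap
print(diff.mean().item())  # average gap
print(torch.allclose(gen_target_logits, call_target_logits, atol=1e-2))
```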
Why does this happen? Did I miss some regularisation done by `generate`, or should I just put my hands up and say "CUDA precision issues"?
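For what it's worth, I know floating-point reductions aren't associative, so the same logical computation can come out slightly differently depending on batch shape and which kernel gets dispatched. A toy illustration of that effect (a sketch; the exact magnitude depends on your GPU and dtype):

```python
import torch

torch.manual_seed(0)
x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

single = x @ w                          # matmul with batch size 1
batched = torch.cat([x, x], dim=0) @ w  # same row inside a larger batch
# Different batch shapes can pick different kernels / reduction orders,
# so this difference is typically nonzero in low precision.
print((single - batched[:1]).abs().max().item())
```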
Here’s the code I’m using to test this.
```python
import torch
from torch import LongTensor, BoolTensor
from transformers import LlamaForCausalLM

# PAD is assumed to be the tokenizer's pad token id, defined elsewhere.

@torch.no_grad()
def test(model: LlamaForCausalLM, input_ids: LongTensor, attention_mask: BoolTensor):
    generated = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=20,
        min_new_tokens=20,
        do_sample=False,
        temperature=None,
        use_cache=False,
        top_p=None,
        top_k=None,
        output_logits=True,
        return_dict_in_generate=True,
    )
    # Get the target logits from calling `model.generate`.
    gen_logits = torch.stack(generated.logits, dim=1)
    gen_target_logits = gen_logits.max(dim=2)[0]
    # We can calculate the answer and its mask from these logits.
    gen_answer = gen_logits.argmax(dim=2)
    gen_answer_mask = gen_answer != PAD
    # Combine the input query and answer to pass to __call__.
    combined_ids = torch.cat([input_ids, gen_answer], dim=1)
    combined_mask = torch.cat([attention_mask, gen_answer_mask], dim=1)
    # Call the model, and get the logits corresponding to the answer
    # (shifted one position to the left).
    w0 = input_ids.shape[1]
    w1 = gen_answer.shape[1]
    call_logits = model(combined_ids, attention_mask=combined_mask).logits[:, w0 - 1 : w0 + w1 - 1]
    call_target_logits = call_logits.max(dim=2)[0]
    call_answer = call_logits.argmax(dim=2)
    # These two are exactly equal (which is good).
    print(gen_answer)
    print(call_answer)
    # These two should be exactly equal, and they are _almost_ equal.
    print(gen_target_logits)
    print(call_target_logits)
```
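For completeness, this is roughly how I drive the test (the loading details and dtype here are assumptions, not part of the question itself):

```python
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16
).to("cuda").eval()

# May be None for this tokenizer; set it explicitly if needed.
PAD = tokenizer.pad_token_id

prompt = "Answer the following question in a few words and with no punctuation. ..."
enc = tokenizer(prompt, return_tensors="pt").to("cuda")
test(model, enc.input_ids, enc.attention_mask)
```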