Performance of Llama-3.2-1B-Instruct on the RWKU Utility General (MMLU) and Reasoning (Big Bench Hard) subsets drops when batch size is increased from 1 to 4 (SOLVED)

I recently tried to implement an unlearning paper, and as part of that I wrote code to evaluate Llama 3.2 1B Instruct on the utility_general subset of the RWKU dataset (https://huggingface.co/datasets/jinzhuoran/RWKU). With a batch size of 1, the 5-shot score of Llama-3.2-1B-Instruct on utility_general is about 47.3, which is close to the original benchmark. With a batch size of 4, however, the score drops to 29.7.

I don’t understand what could be causing this.

The same thing happens with a 3-shot evaluation on the Big Bench Hard dataset (utility_reason subset of RWKU): performance drops from 33.5 at BS 1 to 11.0 at BS 4.

I also used the prompt template from https://huggingface.co/datasets/meta-llama/Llama-3.2-3B-Instruct-evals to rule out an issue with the prompt, but the performance drop still happens.

With BS = 4, the predictions look like this:

######################################################################################
The best answer is C C’
The best answer is C C’
The best answer is A A’
The best answer is D D’
[DEBUG][get_next_word_predictions] Row 0 - Sequence length (from attention mask): 756
[DEBUG][get_next_word_predictions] Row 0 - Logits for next token: [10.3125, 8.5625, 6.4375, 6.4375, 7.71875,
[DEBUG][get_next_word_predictions] Row 0 - Logits for candidate tokens: [9.625, 7.78125, 7.1875, 8.25]
[DEBUG][get_next_word_predictions] Probabilities: [0.66796875, 0.10546875, 0.058349609375, 0.1689453125]
[DEBUG][get_next_word_predictions] Row 0 - Predicted index in candidate tokens: 0
[DEBUG][get_next_word_predictions] Predicted token id: 362 | Decoded: ' A'
[DEBUG][get_next_word_predictions] Row 1 - Sequence length (from attention mask): 780
[DEBUG][get_next_word_predictions] Row 1 - Logits for next token: [0.55078125, 8.125, 2.703125, 3.421875, 3.03125,
[DEBUG][get_next_word_predictions] Row 1 - Logits for candidate tokens: [3.96875, 3.796875, 3.296875, 5.3125]
[DEBUG][get_next_word_predictions] Probabilities: [0.162109375, 0.1357421875, 0.08251953125, 0.62109375]
[DEBUG][get_next_word_predictions] Row 1 - Predicted index in candidate tokens: 3
[DEBUG][get_next_word_predictions] Predicted token id: 423 | Decoded: ' D'
[DEBUG][get_next_word_predictions] Row 2 - Sequence length (from attention mask): 802
[DEBUG][get_next_word_predictions] Row 2 - Logits for next token: [1.234375, -2.96875, 1.4765625, 2.609375, 1.8125
[DEBUG][get_next_word_predictions] Row 2 - Logits for candidate tokens: [21.375, 20.875, 20.625, 20.25]
[DEBUG][get_next_word_predictions] Probabilities: [0.416015625, 0.251953125, 0.1962890625, 0.134765625]
[DEBUG][get_next_word_predictions] Row 2 - Predicted index in candidate tokens: 0
[DEBUG][get_next_word_predictions] Predicted token id: 362 | Decoded: ' A'
[DEBUG][get_next_word_predictions] Row 3 - Sequence length (from attention mask): 732
[DEBUG][get_next_word_predictions] Row 3 - Logits for next token: [8.0625, 6.03125, 6.6875, 2.984375, 3.40625, 7.84375,
[DEBUG][get_next_word_predictions] Row 3 - Logits for candidate tokens: [6.6875, 6.8125, 6.0625, 7.15625]
[DEBUG][get_next_word_predictions] Probabilities: [0.234375, 0.265625, 0.125, 0.375]
[DEBUG][get_next_word_predictions] Row 3 - Predicted index in candidate tokens: 3
[DEBUG][get_next_word_predictions] Predicted token id: 423 | Decoded: ' D'

With BS = 1:

######################################################################################
The best answer is C C’
[DEBUG][get_next_word_predictions] Row 0 - Sequence length (from attention mask): 780
[DEBUG][get_next_word_predictions] Row 0 - Logits for next token: [-1.3828125, -3.125, -0.439453125, 1.5625, 0.2158203125,
[DEBUG][get_next_word_predictions] Row 0 - Logits for candidate tokens: [22.25, 22.75, 23.625, 21.875]
[DEBUG][get_next_word_predictions] Probabilities: [0.13671875, 0.2265625, 0.54296875, 0.09423828125]
[DEBUG][get_next_word_predictions] Row 0 - Predicted index in candidate tokens: 2
[DEBUG][get_next_word_predictions] Predicted token id: 356 | Decoded: ' C'
######################################################################################
[DEBUG][get_next_word_predictions] Logits: torch.Size([1, 802, 128256])
The best answer is A A’
[DEBUG][get_next_word_predictions] Row 0 - Sequence length (from attention mask): 802
[DEBUG][get_next_word_predictions] Row 0 - Logits for next token: [1.265625, -2.9375, 1.578125, 2.640625,
[DEBUG][get_next_word_predictions] Row 0 - Logits for candidate tokens: [21.375, 20.875, 20.625, 20.25]
[DEBUG][get_next_word_predictions] Probabilities: [0.416015625, 0.251953125, 0.1962890625, 0.134765625]
[DEBUG][get_next_word_predictions] Row 0 - Predicted index in candidate tokens: 0
[DEBUG][get_next_word_predictions] Predicted token id: 362 | Decoded: ' A'
######################################################################################
[DEBUG][get_next_word_predictions] Logits: torch.Size([1, 732, 128256])
The best answer is D D’
[DEBUG][get_next_word_predictions] Row 0 - Sequence length (from attention mask): 732
[DEBUG][get_next_word_predictions] Row 0 - Logits for next token: [-1.8515625, -2.921875, 1.390625, 1.6640625,
[DEBUG][get_next_word_predictions] Row 0 - Logits for candidate tokens: [22.75, 23.375, 24.25, 24.625]
[DEBUG][get_next_word_predictions] Probabilities: [0.072265625, 0.134765625, 0.322265625, 0.470703125]
[DEBUG][get_next_word_predictions] Row 0 - Predicted index in candidate tokens: 3
[DEBUG][get_next_word_predictions] Predicted token id: 423 | Decoded: ' D'

Strangely, the logits themselves differ so much between the two batch sizes that the predictions change drastically.

The prediction logic takes the logit values at the candidate token indices and applies a softmax over them to predict the option [A, B, C, D]:

row = logits[row_idx, length - 1]
choice_logits = row[list(candidate_token_ids)]
probs = torch.softmax(choice_logits, dim=-1)
pred_idx = int(torch.argmax(probs).item())
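For anyone hitting a similar batch-size discrepancy, here is a minimal sketch of one thing worth checking. It is not necessarily the bug in this thread. Indexing with `length - 1` only lands on the last real token when the batch is right-padded. With left padding, that index falls inside the padding or mid-prompt for every row shorter than the longest one. The helpers below (`last_token_positions` and `predict_choices` are hypothetical names, not part of my evaluation code) locate the last non-padding position from the attention mask, so the same logic works for either padding side. They assume `logits` has shape (batch, seq_len, vocab) and `attention_mask` has shape (batch, seq_len) with 1 for real tokens.

```python
import torch


def last_token_positions(attention_mask: torch.Tensor) -> torch.Tensor:
    """Return, per row, the index of the last token whose mask value is 1.

    Works for both left- and right-padded batches.
    """
    seq_len = attention_mask.shape[1]
    # Reverse each row; argmax returns the first maximal value, i.e. the
    # distance of the last real token from the end of the sequence.
    flipped = attention_mask.flip(dims=[1]).long()
    offset_from_end = flipped.argmax(dim=1)
    return seq_len - 1 - offset_from_end


def predict_choices(logits, attention_mask, candidate_token_ids):
    """Pick the most likely candidate token at each row's last real position."""
    positions = last_token_positions(attention_mask).to(logits.device)
    batch_idx = torch.arange(logits.shape[0], device=logits.device)
    rows = logits[batch_idx, positions]                  # (batch, vocab)
    choice_logits = rows[:, list(candidate_token_ids)]   # (batch, n_choices)
    # Softmax in float32 to avoid precision loss with bf16/fp16 logits.
    probs = torch.softmax(choice_logits.float(), dim=-1)
    return probs.argmax(dim=-1)


# Toy batch: row 0 is left-padded, row 1 is right-padded.
mask = torch.tensor([[0, 0, 1, 1],
                     [1, 1, 0, 0]])
toy_logits = torch.zeros(2, 4, 6)
toy_logits[0, 3, 5] = 10.0  # last real token of row 0 favours token id 5
toy_logits[1, 1, 2] = 10.0  # last real token of row 1 favours token id 2
print(predict_choices(toy_logits, mask, [2, 4, 5]))
```

A quick sanity check in the same spirit is to run the same prompt once alone and once inside a batch, then compare the candidate logits. In the logs above, only the 802-token row gives the same candidate logits at BS = 4 as at BS = 1.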

EDIT: Prompt slicing was not implemented correctly. All good now! 🙃
