Accuracy is greater than 1 for Prompt Tuning

I'm currently following the Prompt Tuning for CausalLM example.

I'm trying to compute the accuracy after each iteration, but I consistently get a higher (sometimes double) testing accuracy compared to training accuracy. The testing accuracy has also been greater than 1 in many instances.

Here is the modified training/testing loop:

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    counter = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        if counter >= 100:
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()
        total_correct += torch.sum(torch.argmax(outputs.logits[:, 8:], dim=-1) == batch["labels"]).item()
        total_samples += batch["labels"].size(0)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        counter += 1

    model.eval()
    eval_loss = 0
    eval_preds = []
    eval_correct = 0
    eval_samples = 0
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )
        eval_correct += torch.sum(torch.argmax(outputs.logits[:, 8:], dim=-1) == batch["labels"]).item()
        eval_samples += batch["labels"].size(0)

    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    train_accuracy = total_correct / total_samples
    eval_accuracy = eval_correct / eval_samples
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=} {train_accuracy=} {eval_accuracy=}")

The shape of the logits is [16, 264, 30522] (batch size, prompt tuning virtual tokens + max_length, vocab size).
Both torch.argmax(outputs.logits[:, 8:], dim=-1) and batch["labels"] have shape [16, 256].
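
For reference, here is a minimal, self-contained sketch of those shapes (dimensions taken from above, tensor contents made up), showing that the comparison is elementwise per token:

import torch

logits = torch.randn(16, 264, 30522)          # (batch, 8 virtual tokens + 256, vocab size)
preds = torch.argmax(logits[:, 8:], dim=-1)   # [16, 256] after dropping the virtual tokens
labels = torch.randint(0, 30522, (16, 256))   # [16, 256]
matches = preds == labels                     # [16, 256] boolean tensor, one entry per token
print(torch.sum(matches).item())              # sums over up to 16 * 256 = 4096 entries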

Would appreciate it if you could point out the logical error in the code. If there's any additional info I should provide, please let me know. Thanks!

Accuracy = winners / all. Either "all" is too small or "winners" is too big.
If batch["labels"].size(0) is equal to your batch_size, then the problem is likely in the line above it. I would break that line up into smaller steps to debug why it is bigger than it should be.
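
As a concrete illustration, here is a minimal sketch of a per-token count for the eval loop (the same change applies to the training loop), assuming the usual Hugging Face convention that ignored/padded label positions are set to -100; adjust the mask if your dataset does something different:

preds = torch.argmax(outputs.logits[:, 8:], dim=-1)   # [16, 256]
mask = batch["labels"] != -100                        # skip ignored/padded positions (assumed -100 convention)
eval_correct += torch.sum((preds == batch["labels"]) & mask).item()
eval_samples += mask.sum().item()                     # denominator counts tokens, not rows

With size(0) in the denominator you were dividing up to 16 * 256 correct tokens by only 16 rows per batch, which is exactly how the "accuracy" climbs past 1. Separately, causal LM logits are usually shifted by one position relative to the labels (the logit at position i predicts the token at position i + 1), so a strict token accuracy would also account for that offset, but that is not what pushes the value above 1.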