Currently following Prompt Tuning for CausalLM

I’m trying to compute the accuracy after each epoch correctly, but I consistently get a higher (sometimes double) testing accuracy compared to training accuracy. The testing accuracy has also been greater than 1 in many instances.

Here is the modified training/testing loop:

```python
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    counter = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        if counter >= 100:
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()
        total_correct += torch.sum(torch.argmax(outputs.logits[:, 8:], dim=-1) == batch["labels"]).item()
        total_samples += batch["labels"].size(0)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        counter += 1

    model.eval()
    eval_loss = 0
    eval_preds = []
    eval_correct = 0
    eval_samples = 0
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(
                torch.argmax(outputs.logits, -1).detach().cpu().numpy(),
                skip_special_tokens=True,
            )
        )
        eval_correct += torch.sum(torch.argmax(outputs.logits[:, 8:], dim=-1) == batch["labels"]).item()
        eval_samples += batch["labels"].size(0)

    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    train_accuracy = total_correct / total_samples
    eval_accuracy = eval_correct / eval_samples
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=} {train_accuracy=} {eval_accuracy=}")
```

The shape of the logits is [16, 264, 30522] (batch size, prompt tuning virtual tokens + max_length, vocab size).
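As a sanity check on the slicing, here is a small standalone sketch of the shapes involved (the vocab dimension is scaled down to 100 since only the shapes matter here; 16, 8, and 256 match the dimensions above):

```python
import torch

# Hypothetical shapes matching the ones described above:
# 16 = batch size, 8 = virtual tokens, 256 = max_length.
batch, virtual_tokens, max_length, vocab = 16, 8, 256, 100
logits = torch.randn(batch, virtual_tokens + max_length, vocab)

# Dropping the virtual-token positions leaves [16, 256, vocab],
# and argmax over the vocab dimension gives [16, 256].
preds = torch.argmax(logits[:, virtual_tokens:], dim=-1)

print(logits.shape)  # torch.Size([16, 264, 100])
print(preds.shape)   # torch.Size([16, 256])
```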

Both torch.argmax(outputs.logits[:, 8:], dim=-1) and batch["labels"] have shape [16, 256].
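To make the bookkeeping concrete, here is a toy-sized illustration of how the numerator and denominator in `total_correct / total_samples` are counted (shapes scaled down from [16, 256] to [2, 4]):

```python
import torch

# Toy stand-ins for the prediction and label tensors above:
# 2 sequences of length 4 instead of 16 sequences of length 256.
preds = torch.tensor([[1, 2, 3, 4],
                      [1, 2, 3, 4]])
labels = torch.tensor([[1, 2, 0, 4],
                       [1, 2, 3, 4]])

correct = torch.sum(preds == labels).item()  # counts matching TOKENS: 7
samples = labels.size(0)                     # counts SEQUENCES: 2

print(correct / samples)         # 3.5 -- a token count over a sequence count can exceed 1
print(correct / labels.numel())  # 0.875 -- per-token accuracy stays in [0, 1]
```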

Would appreciate it if you could point out the logical error in the code. If there’s any additional info I should provide please let me know. Thanks!