I'm currently following Prompt Tuning for CausalLM.
I'm trying to compute the accuracy after each iteration, but I consistently get a higher (sometimes double) testing accuracy compared to training accuracy. The testing accuracy has also been greater than 1 in many instances.
Here is the modified training/testing loop:
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    counter = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        if counter >= 100:  # train on at most 100 batches per epoch
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()
        total_correct += torch.sum(torch.argmax(outputs.logits[:, 8:], dim=-1) == batch["labels"]).item()
        total_samples += batch["labels"].size(0)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        counter += 1

    # evaluation
    model.eval()
    eval_loss = 0
    eval_preds = []
    eval_correct = 0
    eval_samples = 0
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )
        eval_correct += torch.sum(torch.argmax(outputs.logits[:, 8:], dim=-1) == batch["labels"]).item()
        eval_samples += batch["labels"].size(0)

    # per-epoch metrics
    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    train_accuracy = total_correct / total_samples
    eval_accuracy = eval_correct / eval_samples
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=} {train_accuracy=} {eval_accuracy=}")
The shape of the logits is [16, 264, 30522] (batch size, prompt-tuning virtual tokens + max_length, vocab size).
Both torch.argmax(outputs.logits[:, 8:], dim=-1) and batch["labels"] have shape [16, 256].
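For reference, here is the accuracy computation in isolation with dummy tensors. Only the shapes mirror my real batch (the values are random placeholders), and vocab_size / num_virtual_tokens are the ones from my setup:

import torch

# Dummy tensors matching the shapes above; values are random placeholders
batch_size, num_virtual_tokens, max_length, vocab_size = 16, 8, 256, 30522
logits = torch.randn(batch_size, num_virtual_tokens + max_length, vocab_size)  # [16, 264, 30522]
labels = torch.randint(0, vocab_size, (batch_size, max_length))                # [16, 256]

preds = torch.argmax(logits[:, 8:], dim=-1)   # [16, 256], virtual-token positions sliced off
correct = torch.sum(preds == labels).item()   # element-wise comparison over the whole tensor
samples = labels.size(0)                      # 16
print(preds.shape, labels.shape, correct, samples)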
I'd appreciate it if you could point out the logical error in the code. If there's any additional info I should provide, please let me know. Thanks!