Sorry for the late reply. I actually used logits[0] because the version I was using passed both the logits and the labels inside logits (though I may also have been misinterpreting them). I agree with you: the correct snippet passes logits, not logits[0], to the argmax:
pred_ids = torch.argmax(logits, dim=-1)
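To illustrate the difference, here is a minimal sketch. It assumes CTC-style logits of shape (batch, time, vocab). The helper get_pred_ids is hypothetical, not a library function. It unwraps [0] only when it receives a tuple or list of outputs, which is the situation my older version produced. On a plain tensor, logits[0] would silently keep only the first utterance in the batch.

```python
import torch

def get_pred_ids(logits):
    """Greedy (argmax) token ids from CTC-style logits.

    Hypothetical helper. Assumes `logits` is either a tensor of shape
    (batch, time, vocab) or a tuple/list whose first element is that tensor
    (e.g. when several outputs are returned together).
    """
    # Unwrap only when handed a container of outputs, never a bare tensor.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    # Argmax over the vocabulary dimension -> shape (batch, time).
    return torch.argmax(logits, dim=-1)

# Toy batch: 2 utterances, 3 frames, vocabulary of 4 tokens.
logits = torch.randn(2, 3, 4)
print(get_pred_ids(logits).shape)  # torch.Size([2, 3])

# Indexing a plain tensor with [0] drops every utterance but the first.
print(torch.argmax(logits[0], dim=-1).shape)  # torch.Size([3])
```

Passing the bare tensor keeps the batch dimension, so every utterance gets decoded.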