Sorry for the late reply. I actually used `logits[0]` because the version I was using passed both the logits and the labels inside `logits` (though I may also have been misinterpreting what they were). I agree with you: the correct snippet should pass `logits`, not `logits[0]`, to the argmax:
```python
pred_ids = torch.argmax(logits, dim=-1)
```
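The two cases can be handled in one place. The sketch below assumes the snippet lives in a `preprocess_logits_for_metrics(logits, labels)` hook and that, depending on the model or library version, `logits` arrives either as a plain tensor or as a tuple whose first element is the logits tensor. The `isinstance` check is my own defensive choice, not something the library does for you. On a plain tensor, `logits[0]` would select only the first example in the batch, which is why it gives wrong shapes there.

```python
import torch


def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to predicted token ids before metrics are computed.

    Assumption: `logits` is either a tensor of shape (batch, seq_len, vocab)
    or a tuple whose first element is that tensor. `labels` is unused here.
    """
    if isinstance(logits, tuple):
        # Tuple case: take the actual logits tensor (element 0).
        logits = logits[0]
    # Tensor case: argmax over the vocabulary dimension keeps the batch dim.
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids


# Example: batch of 1, sequence of 2, vocabulary of 2.
example = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
print(preprocess_logits_for_metrics(example, None))     # tensor([[1, 0]])
print(preprocess_logits_for_metrics((example,), None))  # tensor([[1, 0]])
```

If you use the Hugging Face `Trainer`, this would be passed as `preprocess_logits_for_metrics=preprocess_logits_for_metrics`, and `compute_metrics` would then receive the predicted ids instead of the full logits.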