To me, .expand does not make sense here, because you do not want the logits for the input sequence; you only need them for the class labels.