The code snippet from before looks at the prediction for the first token, which makes predictions an int rather than an iterable list. You can get predictions for all the tokens by changing it like this:
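# Softmax over the last (label) dimension; [0] keeps every token of the first item
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
# argmax over the label dimension gives one label id per token, as a plain list
predictions = probabilities.argmax(dim=-1).tolist()
print(predictions)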
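
If it helps to see the int-versus-list difference in isolation, here is a minimal sketch that does not need a model. The `logits` tensor is made-up data standing in for `outputs.logits`, and its shape (1 sequence, 4 tokens, 3 labels) is an assumption for illustration only; your real shapes will depend on your model and input.

```python
import torch

# Made-up logits standing in for outputs.logits.
# Assumed shape: (batch=1, tokens=4, labels=3).
logits = torch.tensor([[[2.0, 0.1, 0.3],
                        [0.2, 1.5, 0.1],
                        [0.1, 0.2, 3.0],
                        [1.0, 0.9, 0.8]]])

# Drop the batch dimension: shape (4, 3)
probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

# Selecting a single token first leaves a 0-d tensor,
# and .tolist() on a 0-d tensor returns a plain Python int.
first_token_prediction = probabilities[0].argmax(dim=-1).tolist()

# Keeping all tokens leaves a 1-d tensor, so .tolist() returns a list.
predictions = probabilities.argmax(dim=-1).tolist()

print(type(first_token_prediction).__name__, first_token_prediction)  # int 0
print(type(predictions).__name__, predictions)  # list [0, 1, 2, 0]
```

So if you end up with an int where you expected a list, it is worth checking whether an extra index was applied before the argmax.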