Chapter 7, Token classification, has a compute_metrics() function that uses argmax() instead of softmax(). I have attached it below. Why does this not cause problems with gradient backpropagation? Argmax is non-differentiable, so I would expect problems, but the example clearly works, so I must be missing something here.
link:
code:
import numpy as np


def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Remove ignored index (special tokens) and convert to labels
    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
    true_predictions = [
        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }
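
To make the question concrete, here is a small sketch of what the argmax() call does to the logits. The function above receives NumPy arrays and only produces evaluation metrics, so, as far as I can tell, nothing in it is differentiated; the training loss is computed separately from the raw logits. The sketch also checks that taking argmax() of the logits picks the same labels as taking argmax() of the softmax() probabilities. The array shape, the random values, and the softmax helper are my own assumptions for illustration, not part of the course code.

```python
import numpy as np


def softmax(x, axis=-1):
    # Hypothetical helper: numerically stable softmax over the given axis.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


# Made-up logits with an assumed shape of (batch, seq_len, num_labels).
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 5, 9))

# Predictions taken directly from the logits, as in compute_metrics().
pred_from_logits = np.argmax(logits, axis=-1)

# Predictions taken from probabilities instead.
pred_from_probs = np.argmax(softmax(logits), axis=-1)

# softmax preserves the ordering of the logits, so both should agree.
same_predictions = bool(np.array_equal(pred_from_logits, pred_from_probs))
print(same_predictions)
```

If that reading is right, argmax() here is only a way to turn scores into label ids for scoring, and skipping softmax() changes nothing about which label is chosen.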