Precision vs recall when using transformer models?

Hi,

I am using a fine-tuned BERT model for text classification and, of course, I am getting very high accuracy scores. However, accuracy hides how poorly the model does on the minority class, which shows up once I look at the precision/recall metrics. See, for instance:

              precision    recall  f1-score   support

           0      0.962     0.976     0.969     75532
           1      0.516     0.395     0.448      4820

    accuracy                          0.941     80352
   macro avg      0.739     0.686     0.708     80352
weighted avg      0.935     0.941     0.938     80352

My question is: for a given transformer architecture (say BERT), do we know what we can do to improve either precision or recall? (I know there is a tradeoff between the two.) For instance, is there evidence that oversampling the least frequent class could improve precision? What can be done here, beyond having more training data?

Thanks!

You can use class weights in your loss function if one label is more probable than the other, i.e. if your dataset is imbalanced. There are plenty of tutorials on this so it shouldn’t be too hard to find.
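As a rough illustration of where such weights could come from, here is a minimal sketch that derives them from the support counts in the report above. It assumes the common "balanced" heuristic, n_samples / (n_classes * count); the function name is hypothetical and other weighting schemes are just as valid.

```python
def balanced_class_weights(counts):
    """Inverse-frequency weights: n_samples / (n_classes * count_per_class).

    This is one common heuristic, not the only reasonable choice.
    """
    n_samples = sum(counts)
    n_classes = len(counts)
    return [n_samples / (n_classes * c) for c in counts]


# Support counts for class 0 and class 1, taken from the report above.
support = [75532, 4820]
weights = balanced_class_weights(support)
print(weights)  # the rare class gets a much larger weight than the frequent one
```

With weights like these, each mistake on class 1 costs far more in the loss than a mistake on class 0, which usually pushes recall on class 1 up at some cost in precision.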


Thanks @BramVanroy, very useful as usual! Do you have a link in mind that deals specifically with Hugging Face and the Trainer API?


@BramVanroy I think one interesting point is that with support vector machines in a text classification task, for instance, we have clear hyperparameters to tune: C and gamma (the total budget for misclassified points and the parameter of the radial kernel).

For transformers, this seems less clear. It is not obvious which hyperparameters one can tune to improve performance. Any ideas?

@BramVanroy searching over all the questions asked in the forum about this, most of them seem to refer to Trainer.

I may be missing something, but it is not clear to me how I could use that documentation to deal with imbalance directly. There has to be a working example somewhere, I hope?

Thanks!

Here is an example: "How can I use class_weights when training?" (post #7 by nielsr). If you search for “class weight” on the forums you’ll find related examples. This is the important part:

loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([0.2, 0.3]))

where the tensor contains the class weights, one value per class in label order. Refer to the PyTorch documentation for more information about the class weights.
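To put that line in context with the Trainer API, here is a minimal sketch along the lines of the linked post: subclass Trainer and override compute_loss. The class name WeightedLossTrainer and the class_weights argument are illustrative names of my own, not part of the library. The sketch assumes a sequence classification model whose output exposes .logits, and it accepts **kwargs because newer transformers versions pass extra arguments to compute_loss.

```python
import torch
from torch import nn
from transformers import Trainer


class WeightedLossTrainer(Trainer):
    """Trainer that applies per-class weights in the cross-entropy loss.

    Hypothetical helper for illustration; `class_weights` is not a
    built-in Trainer argument.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        # 1-D float tensor with one weight per class, in label order.
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Remove the labels so the model does not compute its own
        # (unweighted) loss.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        weight = None
        if self.class_weights is not None:
            # Match the device and dtype of the logits.
            weight = self.class_weights.to(device=logits.device, dtype=logits.dtype)

        loss_fct = nn.CrossEntropyLoss(weight=weight)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
```

You would then build it like a regular Trainer, passing class_weights=torch.tensor([...]) alongside the usual model, args and datasets.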

For hyperparameters to tune, as you say they are less distinct than for more traditional algorithms. Typically you’d at least do a search for the optimal learning rate, potentially also including other arguments of the optimizer.
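A learning rate search can be as simple as the sketch below: try a few log-spaced values and keep the one with the best validation score. Everything here is illustrative. train_and_eval is a hypothetical callable that you would write yourself (fine-tune with the given learning rate, return e.g. the F1 of the minority class on a validation set), and the range 1e-5 to 1e-4 is only an assumed starting point.

```python
import math


def log_spaced(low, high, num):
    """Return `num` values evenly spaced on a log10 scale from low to high."""
    if num < 2:
        raise ValueError("num must be at least 2")
    lo, hi = math.log10(low), math.log10(high)
    step = (hi - lo) / (num - 1)
    return [10 ** (lo + i * step) for i in range(num)]


def pick_learning_rate(candidates, train_and_eval):
    """Evaluate each candidate and return (best_lr, all_scores).

    `train_and_eval(lr)` is a user-supplied function returning a
    validation score where higher is better.
    """
    scores = {lr: train_and_eval(lr) for lr in candidates}
    best_lr = max(scores, key=scores.get)
    return best_lr, scores


candidates = log_spaced(1e-5, 1e-4, 5)
```

Calling pick_learning_rate(candidates, train_and_eval) then runs one fine-tuning per candidate, so keep the grid small or evaluate on a subset first.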
