Hi,
I am using a fine-tuned bert model for text classification and, of course, I am getting very high accuracy scores. However, this masks a more nuanced reality, which becomes apparent when I look at the precision/recall metrics. See, for instance:
              precision    recall  f1-score   support

           0      0.962     0.976     0.969     75532
           1      0.516     0.395     0.448      4820

    accuracy                          0.941     80352
   macro avg      0.739     0.686     0.708     80352
weighted avg      0.935     0.941     0.938     80352
My question is: for a given transformer architecture (say BERT), do we know what can be done to improve either precision or recall? (I know there is a tradeoff between the two.) For instance, is there evidence that oversampling the least frequent class could improve precision? What can be done here, beyond having more training data?
Thanks!
You can use class weights in your loss function if one label is more probable than the other, or more generally if your dataset is imbalanced. There are plenty of tutorials on this, so they shouldn't be too hard to find.
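For example, weights inversely proportional to the class frequencies are a common starting point. A minimal sketch (the train_labels list below is a made-up stand-in for your actual training labels):

import torch
from collections import Counter

# Hypothetical integer labels for the training set (0 = majority class, 1 = minority class).
train_labels = [0] * 75532 + [1] * 4820

counts = Counter(train_labels)
num_classes = len(counts)
total = sum(counts.values())

# Inverse-frequency weights: the rarer class gets a larger weight in the loss.
class_weights = torch.tensor(
    [total / (num_classes * counts[c]) for c in range(num_classes)],
    dtype=torch.float,
)
print(class_weights)  # roughly tensor([0.53, 8.34]) for the support counts in your report

These weights can then be passed to the loss function, as in the example further down in this thread.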
Thanks @BramVanroy, very useful as usual! Do you have a link in mind that deals specifically with huggingface and the Trainer API?
@BramVanroy I think one interesting point is that, for a text classification task with support vector machines for instance, there are clear hyperparameters to tune, such as C or gamma (the budget for misclassified points and the parameter of the radial kernel, respectively).
For transformers this seems less clear: it is not obvious which hyperparameters one could tune to improve performance. Any ideas?
@BramVanroy searching through the questions asked in the forum about this, most of them seem to refer to Trainer.
I may be missing something, but it is not clear to me how I could use that documentation to deal with imbalance directly. There has to be a working example somewhere, I hope?
Thanks!
Here is an example: How can I use class_weights when training? - #7 by nielsr (Or if you look for “class weight” on the forums you’ll find related examples.) This is the important part:
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([0.2, 0.3]))
where the tensor contains the class weights (one weight per class, in class-index order). Refer to the PyTorch documentation on CrossEntropyLoss for more details about how the weights are applied.
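Since the question was specifically about the Trainer API: the usual way to plug a custom loss into the Trainer is to subclass it and override compute_loss. A minimal sketch of that pattern (the [0.2, 0.3] weights are just the example values from above, and the exact compute_loss signature can vary a bit between transformers versions):

import torch
from torch import nn
from transformers import Trainer

class WeightedLossTrainer(Trainer):
    """Trainer that applies class weights in the cross-entropy loss."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Example weights for a two-class problem; in practice derive them from the class frequencies.
        weight = torch.tensor([0.2, 0.3], device=logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

You then instantiate WeightedLossTrainer exactly like a regular Trainer.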
As for hyperparameters to tune: as you say, they are less clear-cut than for more traditional algorithms. Typically you'd do a search for the optimal learning rate at least, potentially also including other arguments of the optimizer.
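If you want to automate that, the Trainer also has a built-in hyperparameter_search method (backed by Optuna, Ray Tune, and others). A rough sketch with the Optuna backend, assuming you already have tokenised train_dataset/eval_dataset objects and a compute_metrics function (those names are placeholders, not from this thread):

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

MODEL_NAME = "bert-base-uncased"  # placeholder checkpoint

def model_init():
    # A fresh model is created for every trial.
    return AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

def hp_space(trial):
    # Search space: learning rate and number of epochs; extend with other optimizer arguments as needed.
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 2, 4),
    }

args = TrainingArguments(output_dir="hp_search", evaluation_strategy="epoch")

trainer = Trainer(
    model_init=model_init,
    args=args,
    train_dataset=train_dataset,      # assumed: tokenised Dataset objects
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,  # assumed: returns e.g. macro F1 rather than accuracy
)

best_run = trainer.hyperparameter_search(
    hp_space=hp_space,
    n_trials=10,
    direction="maximize",  # maximise the objective built from compute_metrics
    backend="optuna",      # requires optuna to be installed
)
print(best_run.hyperparameters)

Evaluating on macro or minority-class F1 rather than accuracy makes the search actually target the imbalance problem you describe.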