I am building a multi-class text classifier using DistilBERT. I have approximately 96 categories. I am looking for suggestions on how to improve the accuracy.
You can train with more data, more (or fewer) epochs, or a better model.
In addition to what @MattiLinnanvuori mentioned, you may also have some of your categories overlapping, making it more difficult for the model to converge to your expected output.
Yes, that is definitely the case. I have reduced the categories based on what is no longer needed, while still covering 97% of the data. I have also improved the data quality with the help of the data team. I am down to about 40 classes, which improved my accuracy to 60%. Would changing the model from DistilBERT to something else drastically improve results? If yes, any suggestions? I am training on user comments sent to the support team.
How big is your training dataset? Consider increasing it if you are working with a specialised topic.
And are you using any metrics during training to see how the model is converging? Is there a difference between how it performs on training data vs. test data? If there is a difference, you can consider sampling your training data differently to make sure it better represents actual data distribution.
Also plot the errors and see if there is a pattern to the classes in these errors. Is the model more wrong for certain classes? Is it the same pairs of classes that are confusing the model? Look up the confusion matrix. That may help you figure out how to further tune your training data.
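A minimal sketch of that kind of error analysis, counting which (true, predicted) class pairs confuse the model most. The label names here are hypothetical; in practice you would feed in your real labels and predictions, and scikit-learn's `confusion_matrix` does the same thing at scale:

```python
from collections import Counter

def confusion_pairs(y_true, y_pred):
    """Count (true, predicted) label pairs, ignoring correct predictions."""
    return Counter((t, p) for t, p in zip(y_true, y_pred) if t != p)

# Hypothetical predictions on a handful of comments.
y_true = ["offers", "billing", "offers", "billing", "returns"]
y_pred = ["offers", "offers",  "billing", "offers", "returns"]

errors = confusion_pairs(y_true, y_pred)
# most_common() surfaces the class pair that confuses the model the most
print(errors.most_common(1))  # → [(('billing', 'offers'), 2)]
```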
Another, slightly different approach is to use something like BERTopic (GitHub - MaartenGr/BERTopic: Leveraging BERT and c-TF-IDF to create easily interpretable topics) to figure out what labels can be automatically discovered. Then see how that auto-labelling matches your classification. Is it possible that the training data contains mislabelled examples?
train dataset - 840K
val dataset - 83K
test dataset - 9K
This is just this year’s data; I can definitely increase it.
I am looking only at loss for the training dataset; everything else I am calculating on validation. Since my data is highly imbalanced I am using balanced_accuracy_score. Evaluation loss = 0.84 and train loss = 0.77. Both have the same data distribution. I am doing the error analysis from the confusion matrix only, and there is a lot of confusion between this pair: user rejected receipts ~ user education.
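For reference, `balanced_accuracy_score` is just the mean of per-class recall, which is easy to verify by hand. The toy labels below are made up; they show why it is a better fit for imbalanced data than plain accuracy:

```python
from collections import defaultdict

def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall; unlike plain accuracy, a rare class
    counts as much as a frequent one."""
    hits, totals = defaultdict(int), defaultdict(int)
    for t, p in zip(y_true, y_pred):
        totals[t] += 1
        if t == p:
            hits[t] += 1
    recalls = [hits[c] / totals[c] for c in totals]
    return sum(recalls) / len(recalls)

# Majority class always right, minority class missed half the time.
y_true = ["a"] * 8 + ["b"] * 2
y_pred = ["a"] * 8 + ["b", "a"]
print(balanced_accuracy(y_true, y_pred))  # 0.75 = (1.0 + 0.5) / 2
```

Plain accuracy on the same predictions would be 0.9, hiding the weak minority class.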
By plotting the errors, do you mean plotting them against each other to see relationships?
You seem to have enough data.
I meant “plotting” rather loosely: view the confusion matrix and see whether prediction accuracy is low for certain classes and high for others.
Split the classes you can predict with high accuracy from the ones with low accuracy. Then, for the low-accuracy classes, try to find ways to augment that data to improve them. For example, if the “user education” category is consistently mispredicted, are there additional signals or data cleanup you could apply to the source data? You will also have to preprocess your actual data the same way when using the model.
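Splitting classes by per-class accuracy could look like the sketch below. The 0.7 threshold and the class labels are arbitrary choices for illustration:

```python
from collections import defaultdict

def per_class_accuracy(y_true, y_pred):
    """Fraction of correct predictions for each true class."""
    hits, totals = defaultdict(int), defaultdict(int)
    for t, p in zip(y_true, y_pred):
        totals[t] += 1
        hits[t] += int(t == p)
    return {c: hits[c] / totals[c] for c in totals}

def split_by_accuracy(y_true, y_pred, threshold=0.7):
    """Partition classes into strong (>= threshold) and weak sets."""
    acc = per_class_accuracy(y_true, y_pred)
    strong = {c for c, a in acc.items() if a >= threshold}
    weak = set(acc) - strong
    return strong, weak

y_true = ["offers", "offers", "user education", "user education", "user education"]
y_pred = ["offers", "offers", "user education", "offers", "offers"]
strong, weak = split_by_accuracy(y_true, y_pred)
print(strong, weak)  # {'offers'} {'user education'}
```

The weak set is then the shortlist for data cleanup or augmentation.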
If prediction quality is low even with a training set this large, is it possible that some of the classes are overlapping?
By augmenting, do you mean creating synthetic comments similar to those?
I am looking at the data to see what can be improved, but to your question: yes, there is a lot of overlap between classes. It is just the way the support team has defined them. There is some scope to redefine those, and some flexibility with a few classes where low accuracy is acceptable. Even as a layman, if I try to predict the category of a comment, I feel it could fit into 2-3 categories.
I was thinking: if I pass the prediction of the first model into my second model along with the comment, within the same training script, that should improve the prediction, right? (FYI: I have a tier 1 model predicting with 85% accuracy across 8 categories.)
Example: the tier 1 category prediction is ‘OFFERS’. OFFERS has 5 tier 2 categories within it. It should then be easier for the second model to choose from only those 5 categories, right?
Thoughts? Have you done anything like this before?
I have target categories at 3 levels - TIER 1,2,3. Creating separate models for each at the moment.
Augmenting meaning adding extra attributes, for example other fields not currently included in the training data. Or you look at the incorrectly predicted data and discover something in common across those examples. Adding that extra information could help the model train better.
Hierarchical categories can also help if the data supports it. Just note that if the lower-level categories that confuse the model span different higher-level categories, that confusion can propagate upward and create bad predictions at the higher level too.
The model I am working on now has about 2000 categories. Half of the training data is clean and I get 90% accuracy there; the remaining data comes from a different source, where I am hitting 55%. I am reducing categories aggressively for the second dataset through tiering, but haven’t yet figured out a way to drill down from the higher tiers to the lower ones.
One of the things I did to improve accuracy was to add the dataset source as an additional input parameter.
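Adding an attribute like the dataset source can be as simple as prepending a tag to the comment text before tokenization, so the model can condition on it. The tag format below is just one option, not a standard:

```python
def build_input(comment, **fields):
    """Prepend field markers (e.g. source, channel) to the comment text
    so extra attributes become part of the model input."""
    prefix = " ".join(f"[{k.upper()}={v}]" for k, v in sorted(fields.items()))
    return f"{prefix} {comment}" if prefix else comment

print(build_input("Where is my refund?", source="email", channel="support"))
# [CHANNEL=support] [SOURCE=email] Where is my refund?
```

The same `build_input` has to be applied at inference time, otherwise training and serving inputs diverge.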
My source is the same, and I don’t really have helpful additional attributes; I need to make the prediction from the first comment. But I can also split the data and see how my top, more frequent categories are doing. I guess I will have to get my hands dirty and just dive into the data on my own; maybe that will help.
Have you ever applied hierarchical classification, btw? Also, which model are you using? Do you think changing the model at this stage could also help?
I am using DistilBERT/BERT. I haven’t yet gone far into hierarchical classification, still reading up.
In the past I haven’t been able to get significantly better results when I have swapped models; perhaps my choice of models has not been great. The time spent figuring out a different model hasn’t felt worth the improvements, if any. So I avoid it unless I read somewhere, or someone tells me, to use a different model. Training and retraining models is a huge time sink for me.
Good luck with the next steps. If you work out a hierarchical model, I would be interested to know how you approached it and your results.
Yeah, I am trying a few techniques. I will post my findings here by next week. Thanks for all the help!