I’m working with the CLEF dataset for a research project. It contains ~480 users, each with somewhere between 10 and 2,000 text posts, and each user has an associated binary label (0/1) indicating whether the user is depressed.
So, basically, I have to take all the posts of a user and classify the label of that user.
I’m thinking about how to solve this with HF transformers (BERT or something similar). If anyone has any suggestions or pointers (a notebook, etc.), please share; it would be really helpful. Thanks!
Also, I’d like to cc @lewtun and @sgugger.
Hey @rasel277144, if I understand correctly, you’d like to classify whether a user is “depressed” based on their posts?
In that case, you could concatenate all of a user’s posts and treat it as a standard text classification problem; Sylvain has created a nice tutorial for this task here.
Having said that, you will probably run into limitations with the maximum context size of models like BERT (512 tokens, typically just a few paragraphs), so you might want to see whether models like BigBird or Longformer can help, as their context size is 8x that of BERT. If that’s still not sufficient, you could adapt some of the suggestions in this thread to text classification (e.g. create an embedding for each user post, average the embeddings, and then feed those embeddings to a simple logistic regression classifier).
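To make the embedding-averaging idea concrete, here’s a minimal sketch. The per-post embeddings are random placeholders standing in for whatever encoder you use; the dimensions and user counts are toy values:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def user_embedding(post_embeddings: np.ndarray) -> np.ndarray:
    """Collapse a (num_posts, dim) matrix into one vector per user
    by taking the mean over posts."""
    return post_embeddings.mean(axis=0)

# Toy data: 4 users with a variable number of 8-dim "post embeddings"
# (in practice these would come from a real encoder).
rng = np.random.default_rng(0)
users = [rng.normal(size=(n, 8)) for n in (10, 3, 7, 5)]
labels = [0, 1, 0, 1]

X = np.stack([user_embedding(u) for u in users])  # shape (4, 8)
clf = LogisticRegression().fit(X, labels)
preds = clf.predict(X)  # one 0/1 prediction per user
```

The key point is that averaging gives you a fixed-size feature vector per user regardless of how many posts they have (10 or 2,000), which is what makes a simple classifier like logistic regression applicable.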
PS: I put “depressed” in quotes because I assume this is not a phenomenon we can hope to capture accurately from written text alone. I also suggest treading very carefully in this domain, as there are plenty of public examples where using NLP to diagnose patient well-being led to bad outcomes.
Wow, @lewtun! Thank you so much for all the amazing pointers; they clear up a lot of my confusion about how to tackle this long-sequence issue. Very helpful.
Btw, I totally agree with your last point. Personally, I also don’t think natural language alone is a good signal for capturing such a phenomenon. I’m actually trying to implement a paper that used this data for classification with a CNN as the feature extractor, so I’m attempting the same task with transformer-based models.
Good luck @rasel277144! One last thing I forgot to mention: if you decide to adopt the embedding-based approach, you will probably get better results if you use sentence-transformers, which uses a siamese network to obtain “better” embeddings than what you’d typically get from vanilla BERT by just extracting the last hidden states. You can find a performance comparison across various tasks/datasets here, and many of the models are available on the Hub.
Thank you so much @lewtun! I didn’t know about sentence-transformers; I will definitely check it out.
I am a bit confused about this part that you suggested above:
“e.g. you could create an embedding for each user post, average the embeddings, and then use those embeddings for a simple logistic regression classifier.”
So, let’s say the first user has 10 posts. I use a sentence-transformer to get an embedding for each post (or for each sentence of all that user’s posts) individually, and then repeat this for every user. But before feeding into the classifier, I have to merge all the post embeddings of each user, so 10 embeddings for the first user. Concatenation would probably make the vector too big, so are you suggesting I take the average of the 10 embedding vectors?
Yes, that’s the idea! The goal is to create, in some sense, a single embedding for each user, and the quick ’n’ dirty way to do that is to average the embedding vectors of each post.
One drawback with this approach is that it will average out subtle correlations associated with each post and so you might need to try a weighted average (e.g. based on the post length or other potentially interesting criteria) to get good results.
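A length-weighted average can be sketched as follows (weighting by post length is just one heuristic; any per-post score could be substituted):

```python
import numpy as np

def weighted_user_embedding(post_embeddings, post_lengths):
    """Average per-post embeddings, weighting each post by its length
    so longer posts contribute more to the user vector."""
    weights = np.asarray(post_lengths, dtype=float)
    weights /= weights.sum()                      # normalise to sum to 1
    return weights @ np.asarray(post_embeddings)  # weighted sum of rows

# Toy example: 3 posts with constant 4-dim embeddings 1.0, 2.0, 3.0
embs = np.ones((3, 4)) * np.array([[1.0], [2.0], [3.0]])
lengths = [10, 20, 70]  # e.g. word counts per post
vec = weighted_user_embedding(embs, lengths)
# weights are [0.1, 0.2, 0.7], so each component is
# 0.1*1 + 0.2*2 + 0.7*3 = 2.6
```

The plain average from earlier is the special case where all posts get equal weight.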
Also, my experience has generally been that you get better results with fine-tuning, so I would also compare the embedding approach against the simplest thing of concatenating all the posts and truncating anything longer than the context size of the model (BERT, BigBird etc).
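The concatenate-and-truncate baseline is essentially just this (assuming a BERT checkpoint, whose context limit is 512 tokens; the post strings are placeholders):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Naive concatenation of one user's posts into a single string.
user_posts = ["first post text", "second post text", "third post text"]
text = " ".join(user_posts)

# truncation=True drops everything past the model's context window,
# so the encoded sequence never exceeds max_length tokens.
enc = tokenizer(text, truncation=True, max_length=512)
```

The resulting `input_ids` can then go straight into a standard sequence classification fine-tuning loop, e.g. the one from Sylvain’s tutorial linked above.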
One other thing regarding long contexts: you can also try a sliding-window approach, as suggested in this thread.
This is actually how the question-answering pipelines handle long contexts, and I think the same trick could work nicely in your case.
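The sliding-window idea can be sketched in plain Python over token ids (the window/stride values are just illustrative; HF tokenizers can also produce these chunks for you via `return_overflowing_tokens=True` with a `stride` argument):

```python
def sliding_windows(token_ids, window=512, stride=256):
    """Split a long token-id sequence into chunks of at most `window`
    tokens, where consecutive chunks overlap by `window - stride`
    tokens so no context is lost at the boundaries."""
    chunks = []
    start = 0
    while True:
        chunks.append(token_ids[start:start + window])
        if start + window >= len(token_ids):
            break  # this chunk already covers the end of the sequence
        start += stride
    return chunks

# A 1200-token sequence splits into overlapping 512-token windows.
ids = list(range(1200))
chunks = sliding_windows(ids, window=512, stride=256)
# chunk lengths: [512, 512, 512, 432]
```

Each chunk gets classified (or embedded) separately and the per-chunk outputs are then pooled, e.g. averaged, back to a single prediction per user.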
Thank you so much @lewtun for all your amazing suggestions, it helps a lot!