Text classification training on long text

Hey everyone,

I’ve been looking into methods for training and running inference on long text sequences. I’ve historically used BERT/RoBERTa, and for the most part this has worked well. Now that I’m trying to push the boundaries a bit, I’ve started working with longer sequences. I know I could use models designed for long text, like Longformer or XLNet, but the vast majority of my texts are quite short; only every now and then do I get a long one.

For token classification, I have been playing around with the tokenizer’s “stride” option, as well as my own chunking implementation (written before stride was available). Since NER mostly relies on local context to label each token, I don’t think I need to adjust my training set or techniques, and I can keep using the same model.

For text classification, however, I think the model would benefit greatly from seeing the entire document at once when assigning a class. I have adapted my token classification chunking logic to text classification at inference time, but I’d love to train on the full length of the documents too.

Currently I take any long document, tokenize it to check its length, and keep the first 256 and the last 256 tokens. I have a few possible ideas for how to move forward:
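The current head-plus-tail truncation can be sketched in a few lines of plain Python over token ids. The function name `head_tail_truncate` is hypothetical, and the sketch assumes the ids exclude special tokens; with BERT/RoBERTa you would likely have to reserve two of the 512 positions for them (for example `head=255, tail=255`).

```python
def head_tail_truncate(token_ids, head=256, tail=256):
    """Hypothetical helper: keep the first `head` and last `tail` token ids.

    Sequences that already fit in head + tail tokens are returned unchanged.
    """
    if len(token_ids) <= head + tail:
        return list(token_ids)
    # Slice the tail by absolute position so that tail=0 keeps nothing
    # (token_ids[-0:] would return the whole list).
    return list(token_ids[:head]) + list(token_ids[len(token_ids) - tail:])


ids = list(range(1000))
kept = head_tail_truncate(ids)
print(len(kept), kept[255], kept[256])  # 512 255 744
```

Everything between the two kept spans is simply dropped, which is the information loss the options below try to avoid.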

  1. Use the stride function of the tokenizer during my own dataset prep to create multiple samples out of each document that’s longer than 512 tokens. Each chunk gets fed into the model as a separate sample with the document’s ground-truth label, so I could end up with 2-3 “variations” of the same document, of different lengths, all with the same label.
  2. During tokenization in the preprocessing step, add an “index” to each data entry; if a document is too long, split it into multiple strided chunks that share the same index. Then modify the training loop so that, before computing the loss, it averages the logits of all predictions with the same “index”.
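Both options start from the same strided chunking. Hugging Face fast tokenizers can already do the splitting when called with `truncation=True`, `max_length`, `stride` and `return_overflowing_tokens=True`, and they return an `overflow_to_sample_mapping` that plays the role of the “index” in option 2. The minimal sketch below instead works on plain lists of token ids so the bookkeeping is visible; the function names, the 512/128 window settings and the list-based averaging are assumptions for illustration. In a real training loop the averaging would be done on the logits tensor before the loss, so gradients flow back through every chunk. One practical consequence of option 2 is that all chunks of a document presumably need to land in the same batch (or be accumulated across steps) for the average to cover the whole document.

```python
from collections import defaultdict


def strided_chunks(token_ids, max_len=512, stride=128):
    """Hypothetical helper: split ids into windows of at most `max_len` tokens.

    Consecutive windows overlap by `stride` tokens (overlap, not step size,
    which is how `stride` is used in Hugging Face tokenizers).
    """
    if not 0 <= stride < max_len:
        raise ValueError("stride must be in [0, max_len)")
    step = max_len - stride
    chunks = []
    start = 0
    while True:
        chunks.append(token_ids[start:start + max_len])
        if start + max_len >= len(token_ids):
            break
        start += step
    return chunks


def build_option1_samples(docs, max_len=512, stride=128):
    """Option 1: every chunk becomes its own sample with the document's label.

    `doc_index` is kept as well, which is all option 2 needs on top of this.
    """
    samples = []
    for doc_index, (token_ids, label) in enumerate(docs):
        for chunk in strided_chunks(token_ids, max_len, stride):
            samples.append({"input_ids": chunk, "label": label, "doc_index": doc_index})
    return samples


def mean_pool_by_index(chunk_logits, doc_indices):
    """Option 2: average the per-chunk logits of chunks sharing a doc index."""
    sums = {}
    counts = defaultdict(int)
    for logits, idx in zip(chunk_logits, doc_indices):
        acc = sums.setdefault(idx, [0.0] * len(logits))
        for j, value in enumerate(logits):
            acc[j] += value
        counts[idx] += 1
    return {idx: [v / counts[idx] for v in acc] for idx, acc in sums.items()}


# Two toy documents: 1000 ids (label 1) and 100 ids (label 0).
docs = [(list(range(1000)), 1), (list(range(100)), 0)]
samples = build_option1_samples(docs)
print([s["doc_index"] for s in samples])  # [0, 0, 0, 1]
print(mean_pool_by_index([[1.0, 3.0], [3.0, 1.0], [2.0, 2.0], [0.5, -0.5]], [0, 0, 0, 1]))
# {0: [2.0, 2.0], 1: [0.5, -0.5]}
```

With option 1 the loss is computed once per chunk; with option 2 it would be computed once per document on the pooled logits.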

My question to you guys is: has anyone already done this, or is it already built into HuggingFace and I just missed it? Also, will option 2) be significantly better than option 1)?

If anyone else arrives here looking for an answer, my research has yielded the following observations:

There is already a library called BELT (GitHub - mim-solutions/bert_for_longer_texts: a BERT classification model for texts longer than 512 tokens; the text is first divided into smaller chunks, which are fed to BERT, and the intermediate results are pooled, and the implementation allows fine-tuning) that does option 2 above. However, it only works for binary classification. The issues identified in that GitHub repo speak to the problems with my option 1, and I observed the same issues.

If you pre-split a document, the labels for the whole document may not apply to an individual chunk. This is a problem during both training and inference. During training, you can confuse the model by providing labels that don’t line up with the chunk of text. During inference, mean-pooling the chunks’ logits can dilute them so much that no label is predicted at all, which pushes you toward max-pooling, but max-pooling can end up predicting too many labels.
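The mean- vs max-pooling trade-off for multi-label inference can be seen with a toy example. The logits are made up for illustration (not measured), the label names `A`/`B`/`C` and the helper functions are hypothetical, and a label is predicted when the sigmoid of its pooled logit exceeds 0.5.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def pool(chunk_logits, mode):
    """Pool per-chunk logits column-wise (one column per label)."""
    columns = list(zip(*chunk_logits))
    if mode == "mean":
        return [sum(col) / len(col) for col in columns]
    if mode == "max":
        return [max(col) for col in columns]
    raise ValueError(f"unknown pooling mode: {mode}")


def predict_labels(chunk_logits, label_names, mode, threshold=0.5):
    """Multi-label prediction: keep every label whose pooled probability > threshold."""
    pooled = pool(chunk_logits, mode)
    return [name for name, z in zip(label_names, pooled) if sigmoid(z) > threshold]


labels = ["A", "B", "C"]  # hypothetical label set
chunk_logits = [          # made-up logits for three chunks of one document
    [3.0, -2.0, -4.0],    # chunk 1 is clearly about A
    [-3.0, 0.5, -4.0],    # chunk 2 weakly hints at B
    [-3.0, -2.0, -4.0],   # chunk 3 is about none of them
]
print(predict_labels(chunk_logits, labels, "mean"))  # []
print(predict_labels(chunk_logits, labels, "max"))   # ['A', 'B']
```

Mean-pooling lets the two chunks that say nothing about A outvote the one that does, while max-pooling lets a single weak chunk add B on its own.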

I ended up settling on Longformer for text classification, and I’m looking into converting my RoBERTa model to use LSG attention (GitHub - ccdv-ai/convert_checkpoint_to_lsg: Efficient Attention for Long Sequence Processing). I will report back if LSG proves more powerful than Longformer.
