Text classification training on long text

Hey everyone,

I’ve been looking into methods for training and running inference on long text sequences. I’ve historically used BERT/RoBERTa, and for the most part this has worked well. Now that I’m trying to push the boundaries a bit, I’ve started working with longer sequences. I know I could use models designed for long inputs, such as Longformer or XLNet, but the vast majority of my texts are short; only every now and then do I get a long one.

For token classification, I have been playing around with the tokenizer’s “stride” option, as well as my own implementation (written before stride was available). Since NER mostly relies on local context to label each token, I don’t think I need to adjust my training set or techniques, and I can keep using the same model.
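For reference, here is a minimal sketch of what a stride-based split does, written in plain Python over a list of token IDs so it runs without downloading a model. `sliding_windows` is a hypothetical helper; `stride` here means the number of tokens shared by consecutive windows, which is how I understand Hugging Face fast tokenizers interpret `stride` when used with `return_overflowing_tokens=True`.

```python
from typing import List

def sliding_windows(token_ids: List[int], window: int, stride: int) -> List[List[int]]:
    """Split token_ids into windows of at most `window` tokens.

    Consecutive windows overlap by `stride` tokens, so each new window
    starts `window - stride` tokens after the previous one.
    """
    if not 0 <= stride < window:
        raise ValueError("stride must satisfy 0 <= stride < window")
    step = window - stride
    windows = []
    start = 0
    while True:
        windows.append(token_ids[start:start + window])
        # Stop once the current window reaches the end of the sequence.
        if start + window >= len(token_ids):
            break
        start += step
    return windows

# Ten tokens, windows of 4, overlap of 2.
print(sliding_windows(list(range(10)), window=4, stride=2))
```

With a real fast tokenizer, a call along the lines of `tokenizer(text, truncation=True, max_length=512, stride=128, return_overflowing_tokens=True)` should return one `input_ids` entry per window plus an `overflow_to_sample_mapping` that points each window back to its source text. Unlike this sketch, the tokenizer also adds special tokens to every window and counts them toward `max_length`.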

For text classification, however, I think the model would benefit greatly from seeing the entire document at once when assigning a class. I have adapted my token classification chunking logic to text classification for inference, but I’d love to train on the full length of the documents too.
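One common way to turn per-chunk outputs into a single document prediction at inference time is to average the chunks’ class probabilities and take the argmax. The sketch below assumes you already have one row of logits per chunk for a single document; `aggregate_chunk_predictions` and `predict_class` are hypothetical names, and mean-of-softmax is just one aggregation choice among several.

```python
import math
from typing import List, Sequence

def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def aggregate_chunk_predictions(chunk_logits: Sequence[Sequence[float]]) -> List[float]:
    """Turn per-chunk logits for ONE document into document-level class
    probabilities by averaging the per-chunk softmax outputs."""
    if not chunk_logits:
        raise ValueError("need at least one chunk")
    chunk_probs = [softmax(row) for row in chunk_logits]
    num_chunks = len(chunk_probs)
    # zip(*...) walks the probabilities class by class across chunks.
    return [sum(col) / num_chunks for col in zip(*chunk_probs)]

def predict_class(chunk_logits: Sequence[Sequence[float]]) -> int:
    probs = aggregate_chunk_predictions(chunk_logits)
    return max(range(len(probs)), key=probs.__getitem__)

# Two chunks of a two-class document: the first is confident about class 0,
# the second leans toward class 1.
print(aggregate_chunk_predictions([[2.0, 0.0], [0.0, 1.0]]))
print(predict_class([[2.0, 0.0], [0.0, 1.0]]))
```

Averaging probabilities lets every chunk vote; taking the max over chunks instead favors documents where a single passage is decisive. Which works better likely depends on whether the class signal is spread out or localized in your documents.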

Currently I take any long document, tokenize it to check its length, and keep only the first 256 and the last 256 tokens. I have a few ideas for how to move forward:
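A head-plus-tail truncation like this can be sketched in a few lines. One detail worth checking: on BERT/RoBERTa the 512-position limit includes the special tokens (`[CLS]`/`[SEP]` or `<s>`/`</s>`), so 256 + 256 content tokens would overflow by two. The sketch below reserves room for them first; `head_tail_truncate` and its defaults are assumptions for illustration.

```python
from typing import List

def head_tail_truncate(token_ids: List[int], max_length: int = 512,
                       num_special_tokens: int = 2) -> List[int]:
    """Keep the beginning and the end of a long token sequence.

    `num_special_tokens` reserves room for tokens such as [CLS]/[SEP]
    (BERT) or <s>/</s> (RoBERTa), which count toward `max_length`.
    """
    budget = max_length - num_special_tokens
    if budget <= 0:
        raise ValueError("max_length must exceed num_special_tokens")
    if len(token_ids) <= budget:
        return list(token_ids)
    head = budget // 2
    tail = budget - head  # tail >= 1 because budget >= 1
    return token_ids[:head] + token_ids[-tail:]

truncated = head_tail_truncate(list(range(1000)))
print(len(truncated), truncated[:3], truncated[-3:])
```

With a Hugging Face tokenizer, one way to use this would be to tokenize with `add_special_tokens=False`, pass `num_special_tokens=tokenizer.num_special_tokens_to_add()`, and then wrap the result with `tokenizer.build_inputs_with_special_tokens(truncated)`.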

  1. Use the tokenizer’s stride option during my own dataset prep to create multiple samples from each document that’s longer than 512 tokens. Each sample would be fed into the model individually with the document’s ground-truth label, so I could end up with 2-3 “variations” of the same document, covering different spans of it, all with the same label.
  2. During tokenization in the pre-processing step, add an “index” to each data entry; if a document is too long for the model, split it into multiple strided chunks that share the same index. Then modify the training loop so that, before computing the loss, the logits of all predictions with the same “index” are averaged.
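Option 1 needs no change to the training loop: each window is an ordinary sample carrying the document’s label. Option 2 changes the loss, so here is a minimal plain-Python sketch of it, assuming each window’s logits arrive with the document index from pre-processing and that all windows of a document land in the same batch. `average_logits_by_doc` and `document_level_loss` are hypothetical helpers, not Hugging Face APIs.

```python
import math
from typing import Dict, List, Sequence, Tuple

def average_logits_by_doc(logits: Sequence[Sequence[float]],
                          doc_index: Sequence[int]) -> Tuple[List[int], List[List[float]]]:
    """Average the logits of all windows that share a document index.

    Returns (document ids in first-seen order, averaged logits per document).
    """
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = {}
    for row, idx in zip(logits, doc_index):
        if idx not in sums:
            sums[idx] = [0.0] * len(row)
            counts[idx] = 0
        sums[idx] = [s + x for s, x in zip(sums[idx], row)]
        counts[idx] += 1
    docs = list(sums)  # dicts keep insertion order (Python 3.7+)
    return docs, [[s / counts[d] for s in sums[d]] for d in docs]

def cross_entropy(logits: Sequence[float], label: int) -> float:
    # log-sum-exp with the max subtracted for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]

def document_level_loss(logits: Sequence[Sequence[float]], doc_index: Sequence[int],
                        doc_labels: Dict[int, int]) -> float:
    """Mean cross-entropy over documents, computed on averaged window logits."""
    docs, averaged = average_logits_by_doc(logits, doc_index)
    losses = [cross_entropy(avg, doc_labels[d]) for d, avg in zip(docs, averaged)]
    return sum(losses) / len(losses)

# Three windows: the first two come from document 0, the third from document 1.
window_logits = [[1.0, 3.0], [3.0, 1.0], [0.0, 2.0]]
print(document_level_loss(window_logits, [0, 0, 1], {0: 0, 1: 1}))
```

In PyTorch the same grouping could presumably be done with `Tensor.index_add_` and `torch.bincount`, so gradients flow back into every window of a document. The main practical cost is the batching assumption: a custom sampler or collator would have to keep each document’s windows together, and a very long document then occupies several slots of a batch.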

My question to you guys: has anyone already done this, or is it already built into HuggingFace and I just missed it? Also, would option 2) be significantly better than option 1)?