Sliding Window Approach for Multilabel Classification

I would like to adapt the sliding window approach in this example for multilabel classification:

The way I envision it:

  1. If an example exceeds the max token limit, it is broken into overlapping chunks with a sliding window.
  2. Each chunk is run through the model, producing a set of predicted labels.
  3. The labels from all chunks are merged, removing duplicates, leaving one set of predicted labels for the parent example.
  4. That aggregated prediction is evaluated with the loss function.
  5. Backpropagate according to that loss.
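The chunking and merging in steps 1 and 3 can be sketched in plain Python. This is only an illustration of the bookkeeping, not a full training loop; `chunk_tokens` and `aggregate_labels` are hypothetical names, and the model call is left out.

```python
def chunk_tokens(tokens, max_len, stride):
    """Step 1: split a token sequence into overlapping windows.

    `stride` is the step between window starts, so consecutive
    windows overlap by `max_len - stride` tokens.
    """
    if len(tokens) <= max_len:
        return [tokens]
    chunks = []
    start = 0
    while start < len(tokens):
        chunks.append(tokens[start:start + max_len])
        if start + max_len >= len(tokens):
            break  # last window reached the end of the sequence
        start += stride
    return chunks


def aggregate_labels(chunk_label_sets):
    """Step 3: union the labels predicted for each chunk,
    removing duplicates."""
    merged = set()
    for labels in chunk_label_sets:
        merged |= set(labels)
    return sorted(merged)
```

One caveat for steps 4 and 5: deduplicating hard labels is not a differentiable operation, so for backprop you would typically pool the per-chunk *logits* (e.g. max- or mean-pool across chunks) and compute the loss on the pooled logits instead; the union-of-labels view above then only applies at inference time.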

Is something like this possible? A challenge I see is that all chunks from a parent example must fit within one batch, otherwise the loss will be calculated on partial examples (e.g., with batch size 4 and an example that is 5 chunks long, chunks 1-4 land in one batch and chunk 5 is either truncated or, worse, pushed into the next batch). I am tempted to simply use Longformer or BigBird and truncate at 4096 tokens, but if my document is 10k tokens long, I run into the same issue.
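One way around the partial-example problem is a custom batching step that never splits a document's chunks across batches. A minimal sketch of such a packer, under the assumption that each document's chunk count is known up front (`pack_batches` is a hypothetical helper, not an existing library function):

```python
def pack_batches(doc_chunk_counts, max_chunks):
    """Greedily assign documents to batches so no document is split
    across a batch boundary.

    doc_chunk_counts: chunk count per document, indexed by doc id.
    max_chunks: soft cap on chunks per batch. A single document
    longer than the cap still gets a batch to itself, so the
    effective batch size must be allowed to grow to fit it.
    Returns a list of batches, each a list of doc ids.
    """
    batches, current, used = [], [], 0
    for doc_id, n in enumerate(doc_chunk_counts):
        # flush the current batch if this document would overflow it
        if current and used + n > max_chunks:
            batches.append(current)
            current, used = [], 0
        current.append(doc_id)
        used += n
    if current:
        batches.append(current)
    return batches
```

In a PyTorch setup this logic would live in a batch sampler or collate function; the loss can then be computed per parent document, since every document's chunks arrive together.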