DataCollator for list of inputs?

Hi everyone,

I know that Longformer exists for long text sequences, but I've built a "sliding window" function for BERT inference that works pretty well. I tokenize the whole report, grab 512-token sections with overlap, run inference on each section, and average the logits before continuing normally. Only about 20% of my data is expected to exceed the 512-token limit.
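For reference, here's roughly what the inference side looks like (a simplified, untested sketch; the model name, window size, and stride are placeholders):

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_NAME = "bert-base-uncased"  # placeholder for my fine-tuned model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()

def sliding_window_logits(text, max_length=512, stride=128):
    # return_overflowing_tokens + stride makes the tokenizer emit
    # overlapping max_length windows in a single call
    enc = tokenizer(
        text,
        max_length=max_length,
        stride=stride,
        truncation=True,
        padding="max_length",
        return_overflowing_tokens=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        out = model(
            input_ids=enc["input_ids"],
            attention_mask=enc["attention_mask"],
        )
    # average the per-window logits into one prediction for the report
    return out.logits.mean(dim=0)
```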

I’ve been looking into training my model with the intention of running inference over multiple segments of the text instead of just truncating to 512 and losing context from later in the text. I’m planning on implementing the following:

  1. A data preprocessor that tokenizes the text, splits it into 512-token chunks (with some overlap), and puts them into a list

  2. A forward() function that takes a list of tokenized chunks instead of a single dictionary, runs inference on every item in the list, averages the logits, and calculates the loss on the averaged logits (see the sketch after this list)
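Here's roughly what I'm picturing for step 2 (an untested sketch; `ChunkAveragingModel` and the `(batch, num_chunks, seq_len)` input layout are just my assumptions):

```python
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification

class ChunkAveragingModel(nn.Module):
    """Hypothetical wrapper: run BERT over every chunk of a report,
    then average the chunk logits before computing the loss."""

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        # input_ids / attention_mask: (batch, num_chunks, seq_len)
        batch_size, num_chunks, seq_len = input_ids.shape
        flat_ids = input_ids.view(-1, seq_len)
        flat_mask = attention_mask.view(-1, seq_len)
        chunk_logits = self.bert(
            input_ids=flat_ids, attention_mask=flat_mask
        ).logits
        # fold the chunks back per report and average their logits
        logits = chunk_logits.view(batch_size, num_chunks, -1).mean(dim=1)
        loss = self.loss_fn(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}
```

Returning a dict with "loss" and "logits" keys should, as far as I can tell, keep this usable with the HF Trainer.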

The question I have is what I'll have to do for the DataCollator. I took a quick look at the code, but I'm not sure whether I'll need to do anything special to make it work with my forward() function. Does anyone have any advice on this?
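My current guess is that the stock collators won't stack a list-of-chunks feature, so I'd need something custom that pads every report to the same number of chunks before tensorizing, along these lines (untested):

```python
import torch

def chunk_collator(features, pad_token_id=0):
    # features: list of dicts, each with "input_ids"/"attention_mask" as
    # lists of chunks (each chunk already padded to seq_len) and a "label"
    max_chunks = max(len(f["input_ids"]) for f in features)
    seq_len = len(features[0]["input_ids"][0])

    input_ids, attention_mask, labels = [], [], []
    for f in features:
        ids = [list(c) for c in f["input_ids"]]
        mask = [list(m) for m in f["attention_mask"]]
        # pad with all-padding chunks so every report has max_chunks chunks
        while len(ids) < max_chunks:
            ids.append([pad_token_id] * seq_len)
            mask.append([0] * seq_len)
        input_ids.append(ids)
        attention_mask.append(mask)
        labels.append(f["label"])

    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
        "labels": torch.tensor(labels),
    }
```

If I go this route, I assume the forward() above would also need to exclude the all-padding chunks from the average (e.g. with a per-report chunk mask), otherwise the pad chunks would drag the averaged logits around.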

Alternatively, has this been done before, or is there an easier way to achieve similar results? I've considered pre-tokenizing my reports, splitting them, decoding each chunk back to text, and attaching the same label to all of the sections. Would that yield "good enough" results?
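In other words, something like this (a sketch that skips the decode round-trip and keeps the tokenized chunks directly; it assumes a `datasets.Dataset` with `text` and `label` columns):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder

def split_into_examples(batch):
    enc = tokenizer(
        batch["text"],
        max_length=512,
        stride=128,
        truncation=True,
        return_overflowing_tokens=True,
    )
    # overflow_to_sample_mapping says which original report each chunk
    # came from, so the report's label can be copied onto every chunk
    enc["label"] = [batch["label"][i] for i in enc["overflow_to_sample_mapping"]]
    enc.pop("overflow_to_sample_mapping")
    return enc

# `dataset` is assumed to be a datasets.Dataset with "text"/"label" columns
dataset = dataset.map(
    split_into_examples, batched=True, remove_columns=dataset.column_names
)
```

Each chunk then becomes an ordinary training example, so the stock DataCollatorWithPadding would work unchanged; the downside is that every section inherits the full report's label even when the relevant text lives in only one of them.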