Token classification on long sentences

I have fine-tuned a BERT model for named-entity recognition. My tasks usually involve long sentences; inputs of more than 512 tokens are very common.

At the moment, I’m working around this by skipping the TokenClassificationPipeline entirely: I use the tokenizer’s max_length and stride parameters to split the sentences into overlapping chunks, and then aggregate the resulting token labels in post-processing. For this, I have to call the PyTorch model manually for each sub-sentence (see the sketch below).
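To illustrate, here is a stripped-down sketch of this chunk-and-aggregate approach. The checkpoint name is a placeholder, the chunks are batched into one forward pass for brevity, and the overlap handling is much simpler than what my real post-processing does:

```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

# "my-finetuned-ner" is a placeholder for the actual fine-tuned checkpoint.
tokenizer = AutoTokenizer.from_pretrained("my-finetuned-ner")
model = AutoModelForTokenClassification.from_pretrained("my-finetuned-ner")
model.eval()


def ner_long_text(text, max_length=512, stride=128):
    # Let the tokenizer split the text into overlapping chunks that each
    # fit within the model's 512-token limit.
    enc = tokenizer(
        text,
        max_length=max_length,
        stride=stride,
        truncation=True,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding=True,
        return_tensors="pt",
    )
    offsets = enc.pop("offset_mapping")
    enc.pop("overflow_to_sample_mapping")

    # Run all chunks through the model (batched here for brevity).
    with torch.no_grad():
        preds = model(**enc).logits.argmax(dim=-1)

    # Naive aggregation: keep the first label seen for each character span,
    # so tokens in the overlap region are not emitted twice.
    labels = {}
    for chunk_preds, chunk_offsets in zip(preds, offsets):
        for pred, (start, end) in zip(chunk_preds.tolist(), chunk_offsets.tolist()):
            if start == end:  # special or padding tokens
                continue
            labels.setdefault((start, end), model.config.id2label[pred])
    return [(text[s:e], label) for (s, e), label in sorted(labels.items())]
```

This works, but it means I’m re-implementing the chunking and aggregation logic that the pipeline abstractions normally hide.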

Is there a native (or at least more efficient) way to do this? I saw that something like this is already implemented in the QuestionAnsweringPipeline.

The reason I’m asking is that I’m trying to deploy this as a serverless endpoint on Amazon SageMaker, and for that to run smoothly the inference script should be as optimized as possible.