Function that breaks up sentence embeddings efficiently for batch inference

Say I’m performing inference with a max batch size of 200 examples and a max token length of 100. Obviously, the number of tokens varies greatly from example to example. Is there a function that can simply take in my inference examples and group them into batches so they can be run as efficiently as possible? I’ll give an example below. For simplicity, I’ll set the max batch size to 2 and the max tokens to 8. xi will denote some inference example, and the arbitrary integer values in xi represent tokens.

x1 = [1,9,6,6]
x2 = [2,5,2,3,5,1,2,7,2,6]
x3 = [2,5,6]
x4 = [2,4,2,5,8,3]

Below is the output I'd want:
output = [ [x2, x4], [x1, x3] ]
and again, just to show the actual vectors, which are now padded appropriately, where 0 represents a mask token (note that x2, which exceeds the max token length, is truncated to 8 tokens):
output = [ [ [2,5,2,3,5,1,2,7], [2,4,2,5,8,3,0,0] ], [ [1,9,6,6], [2,5,6,0] ] ]

As you can see, the function would put the smaller inference examples together, minimizing the number of mask tokens and decreasing the operations during forward propagation. I’m sure I could code this, but it would be awesome if it already exists somewhere.
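In case it helps, here is a minimal sketch of the idea (the function name `bucket_batches` is my own, not a library API): truncate each example to the max token length, sort longest-first so similar-length examples land in the same batch, then pad each batch only to the length of its longest member.

```python
def bucket_batches(examples, max_batch_size, max_tokens, pad_token=0):
    """Group token sequences into padding-efficient batches.

    examples: list of lists of token ids.
    Returns a list of batches; each batch is a list of equal-length
    (padded) sequences.
    """
    # Truncate anything longer than max_tokens, then sort indices
    # longest-first so each batch holds sequences of similar length.
    truncated = [x[:max_tokens] for x in examples]
    order = sorted(range(len(truncated)),
                   key=lambda i: len(truncated[i]), reverse=True)

    batches = []
    for start in range(0, len(order), max_batch_size):
        batch = [truncated[i] for i in order[start:start + max_batch_size]]
        width = len(batch[0])  # longest sequence in this batch
        # Pad each sequence in the batch up to that width.
        batches.append([seq + [pad_token] * (width - len(seq))
                        for seq in batch])
    return batches
```

With the x1–x4 example above and `max_batch_size=2, max_tokens=8`, this produces exactly the two batches shown. Libraries cover parts of this too: `torchtext`'s `BucketIterator` does length-based bucketing, and Hugging Face's dynamic padding (e.g. `DataCollatorWithPadding`) pads each batch only to its longest member, though neither reorders your examples for you in quite this one-shot way.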