Training with sequences of varying length

My question is about padding and batching. For context, I am training models on the CausalLM task (although I’ve had a similar question with other tasks). The sequences in my dataset vary quite a bit in length, from just a few tokens to so many that I have to truncate them to the model’s maximum length.

My code is a bit more involved, but the basic idea is as follows:

from datasets import load_from_disk
from transformers import AutoTokenizer

dataset_location = "path_to_dataset"

dataset = load_from_disk(dataset_location)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b-deduped")
tokenizer.pad_token = tokenizer.eos_token

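def preprocess(example):
    # Pad every sequence to exactly 2048 tokens and truncate anything longer
    return tokenizer(example['text'], max_length=2048, padding='max_length', truncation=True)

# batched=True passes lists of texts to preprocess; the raw text column is dropped
tokenized_dataset = dataset.map(preprocess, batched=True, remove_columns=['text'])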

By padding to the max length, I am creating a very large dataset, and every batch has the maximum sequence length the model allows (2048 tokens), even when most of its sequences are short. I believe this increases my training time and memory requirements.
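To put a rough number on that cost, here is a small sketch that counts what fraction of the tokens in the padded batches are padding, comparing fixed-length padding with padding each batch only to its own longest sequence. The function name `padding_fraction` and the sequence lengths are hypothetical, made up for illustration rather than taken from the dataset above.

```python
def padding_fraction(lengths, batch_size, pad_to=None):
    """Return the fraction of tokens in the padded batches that are padding.

    If pad_to is given, every sequence is padded to that fixed length
    (like padding='max_length'); otherwise each batch is padded only to
    its own longest sequence (dynamic padding).
    """
    total = 0  # tokens after padding
    real = 0   # tokens before padding
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        width = pad_to if pad_to is not None else max(batch)
        total += width * len(batch)
        real += sum(batch)
    return 1 - real / total


lengths = [12, 40, 2048, 300, 25, 60, 1500, 8]  # hypothetical token counts
print(round(padding_fraction(lengths, 4, pad_to=2048), 3))  # 0.756: pad everything to 2048
print(round(padding_fraction(lengths, 4), 3))               # 0.719: pad to longest in batch
print(round(padding_fraction(sorted(lengths), 4), 3))       # 0.522: sort, then pad per batch
```

With these made-up lengths, dynamic padding alone helps only a little because a very long sequence lands in each batch; grouping sequences of similar length (here, simply sorting them) is what removes most of the padding.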

So, my main question: is there a way to create batches whose length depends on the longest sequence in each batch, and should I be doing that? If so, how? And how do I make sure that the sequences grouped together during preprocessing are also trained together as one batch?
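In 🤗 Transformers, the usual way to get this is dynamic padding: drop `padding='max_length'` from `preprocess` (keeping `truncation=True, max_length=2048`) and pass `DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)` to `Trainer` as `data_collator`. Padding then happens when each training batch is assembled, so no pre-padded groups need to be kept together. The plain-Python sketch below illustrates the idea only; `collate_causal_lm` is a hypothetical name, and it returns lists rather than the PyTorch tensors the real collator produces.

```python
from typing import Dict, List


def collate_causal_lm(
    features: List[Dict[str, List[int]]], pad_token_id: int
) -> Dict[str, List[List[int]]]:
    """Right-pad a list of tokenized examples to the longest one in the batch."""
    width = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        ids = f["input_ids"]
        n_pad = width - len(ids)
        batch["input_ids"].append(ids + [pad_token_id] * n_pad)
        # 1 marks real tokens, 0 marks padding the model should ignore
        batch["attention_mask"].append([1] * len(ids) + [0] * n_pad)
        # -100 is the default ignore_index of PyTorch's cross-entropy loss,
        # so padded positions do not contribute to the language-modeling loss
        batch["labels"].append(ids + [-100] * n_pad)
    return batch


example = collate_causal_lm([{"input_ids": [5, 6, 7]}, {"input_ids": [8]}], pad_token_id=0)
print(example["input_ids"])  # [[5, 6, 7], [8, 0, 0]]
```

One difference worth knowing: with `mlm=False`, `DataCollatorForLanguageModeling` sets the label to -100 at every position whose token equals `pad_token_id`, so with `pad_token = eos_token` genuine EOS tokens are also excluded from the loss. The sketch masks by position instead.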

I have gotten around this by sorting my documents by sequence length, manually creating batches of a reasonable length for the data, and then writing my own training loop. This seems excessive, and I’d like to use Trainer instead.
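Trainer has a built-in option along these lines: `TrainingArguments(group_by_length=True)` groups training samples of roughly the same length, which only pays off together with dynamic padding (it can read a precomputed column named by `length_column_name`, default `"length"`). The stand-alone sketch below shows the general technique, not Trainer’s implementation: shuffle, split into large "megabatches", sort by length inside each one, then cut into batches, which keeps some randomness compared with a full sort. The function name and the `mega_factor` parameter are hypothetical.

```python
import random
from typing import List


def length_grouped_batches(
    lengths: List[int], batch_size: int, mega_factor: int = 50, seed: int = 0
) -> List[List[int]]:
    """Return batches of dataset indices whose sequences have similar lengths."""
    rng = random.Random(seed)
    indices = list(range(len(lengths)))
    rng.shuffle(indices)  # randomize which samples share a megabatch
    mega_size = mega_factor * batch_size
    batches = []
    for start in range(0, len(indices), mega_size):
        # Sort only within the megabatch so neighbours have similar lengths
        mega = sorted(indices[start:start + mega_size], key=lambda i: lengths[i], reverse=True)
        batches.extend(mega[j:j + batch_size] for j in range(0, len(mega), batch_size))
    rng.shuffle(batches)  # avoid always training from longest to shortest
    return batches


lengths = [5, 100, 7, 98, 6, 99, 8, 97]  # hypothetical token counts
for batch in length_grouped_batches(lengths, batch_size=2, mega_factor=4):
    print([lengths[i] for i in batch])  # each batch pairs similar lengths
```

Each list of indices can then be padded at collate time, so only the longest sequence in a batch decides how much padding that batch gets.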