Make text data continuous from DatasetDict


I have a question about processing datasets and hope to find some opinions here.
Given a dataset,

    train: Dataset({
        features: ['text'],
        num_rows: 659
    })

whose rows contain text of varying lengths. I want to concatenate the text from all rows and then split it evenly, so that in the end every row has the same length.

Is there a “built-in” way to do this?


Hello Patrick,

This is one option to handle your use case.

  1. Select the tokenizer you wish to use with the dataset, for example the pre-trained GPT-2 one.
    from transformers import GPT2Tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

  2. Load a dataset of your choosing.
    from datasets import load_dataset
    load_data = load_dataset("wikitext", "wikitext-2-v1", split="train")

  3. Concatenate the tokenized examples and split them into sequences of exactly 512 tokens. Within each batch that map passes to the function, tokens left over after the last full 512-token chunk are dropped, and a batch with fewer than 512 tokens in total yields a single shorter sequence, which you will need to filter out or pad. The sequence length is arbitrary and can be chosen to suit the application. Make sure the tokenizer you select is not limited to a specific maximum sequence length; otherwise you may get a warning. You can check this in the configuration file that is downloaded with the tokenizer.

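from itertools import chain

def tokenize(examples):
    seq_length = 512
    # Tokenize every row in the batch (no truncation, no padding).
    examples = tokenizer(examples["text"])
    # Flatten each field (input_ids, attention_mask, ...) into one long list.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the remainder so the batch splits into full seq_length chunks.
    if total_length >= seq_length:
        total_length = (total_length // seq_length) * seq_length
    # Cut every field into consecutive chunks of seq_length tokens.
    result = {
        k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)]
        for k, t in concatenated_examples.items()
    }
    return result
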
  4. Map the tokenize function over the loaded dataset. Remove the original columns so that only input_ids, attention_mask, etc. remain.
    tokenized_dataset = load_data.map(tokenize, batched=True, remove_columns=["text"])

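  5. Filter out any examples that do not have exactly 512 tokens. Alternatively, you can pad them. Note that drop_last=True in the PyTorch DataLoader only drops an incomplete final batch of examples, so it does not remove short sequences by itself.
    filtered_dataset = tokenized_dataset.filter(lambda x: len(x["input_ids"]) == 512)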
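
If you would rather pad than filter, a minimal sketch could look like the one below. It works on plain Python dicts shaped like the rows the map call produces. The helper name pad_to_length and the default pad_id are assumptions for illustration: GPT-2 ships without a dedicated pad token, so the sketch reuses its end-of-text id (50256), which is a common choice but not the only one.

```python
def pad_to_length(example, seq_length=512, pad_id=50256):
    """Right-pad one tokenized row to seq_length (hypothetical helper)."""
    missing = seq_length - len(example["input_ids"])
    if missing <= 0:
        return example  # already full length, nothing to do
    return {
        "input_ids": example["input_ids"] + [pad_id] * missing,
        # 0 marks padded positions so the model can ignore them
        "attention_mask": example["attention_mask"] + [0] * missing,
    }

# Toy row with 3 tokens, padded to length 5 to keep the example readable
short_row = {"input_ids": [10, 11, 12], "attention_mask": [1, 1, 1]}
padded = pad_to_length(short_row, seq_length=5)
```

With a real dataset, something like tokenized_dataset.map(pad_to_length) should apply this row by row instead of the filter.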

You can check the length of the sequences in the dataset to confirm that every one has exactly 512 tokens.
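
To see where a short row can come from and how such a check might look, here is a toy walk-through in pure Python. The function chunk_batch is a hypothetical stand-in that mirrors the concatenate-and-split rule of tokenize, and seq_length is set to 4 only to keep the numbers small.

```python
from collections import Counter
from itertools import chain

def chunk_batch(token_lists, seq_length):
    """Concatenate token lists and cut them into seq_length pieces
    (toy stand-in for the chunking rule, not the real tokenizer)."""
    tokens = list(chain(*token_lists))
    total = len(tokens)
    if total >= seq_length:
        total = (total // seq_length) * seq_length  # drop the remainder
    return [tokens[i : i + seq_length] for i in range(0, total, seq_length)]

# 10 tokens in the batch: two full chunks, the last 2 tokens are dropped
full_batch = chunk_batch([[1, 2, 3], [4, 5, 6, 7, 8], [9, 10]], seq_length=4)
# 3 tokens in the batch: fewer than seq_length, so one short chunk survives
small_batch = chunk_batch([[1, 2], [3]], seq_length=4)

# Histogram of chunk lengths; any key other than seq_length is a short row
lengths = Counter(len(chunk) for chunk in full_batch + small_batch)
```

The same kind of histogram, built from len(row["input_ids"]) over the rows of filtered_dataset, should contain the single key 512 once the filter has been applied.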

Hope this helps!