Streaming datasets and batched mapping

I’m exploring using streaming datasets with a function that preprocesses the text, tokenizes it into training samples, and then applies some noise to the input_ids (à la BART pretraining). It seems to be working really well, and saves a huge amount of disk space compared to downloading a dataset like OSCAR locally.

Since a lot of the examples in OSCAR are much longer than my model’s max size, I’ve been truncating each example at the last whitespace before the end of the first model-size chunk, and throwing away a ton of data. Not the end of the world, but it feels… wasteful.
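For reference, the truncation described above can be sketched as follows. This is a minimal, hypothetical version that counts characters as a stand-in for tokens (`max_chars` is an assumption, not a real model limit). It shows why the approach is wasteful: everything after the cut is simply dropped.

```python
import re


def truncate_at_whitespace(text: str, max_chars: int) -> str:
    """Keep at most max_chars characters, cutting back to the last whitespace
    so no word is split. Everything after the cut is discarded."""
    if len(text) <= max_chars:
        return text
    # Look one character past the limit so a word that ends exactly at
    # max_chars (followed by whitespace) is kept whole.
    window = text[: max_chars + 1]
    last_ws = max((m.start() for m in re.finditer(r"\s", window)), default=-1)
    if last_ws <= 0:
        # No usable whitespace in the window: fall back to a hard cut.
        return text[:max_chars]
    return text[:last_ws]


doc = "the quick brown fox jumps over the lazy dog"
kept = truncate_at_whitespace(doc, 12)
wasted_fraction = 1 - len(kept) / len(doc)
```

On long OSCAR documents the discarded tail is most of the text, which is the motivation for chunking instead.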

I took a look at how MappedExamplesIterable handles batching, and I had a realization. Since __iter__ fetches a batch from the dataset and then just yields each output of the mapped function, there’s no reason the number of processed results needs to be the same as the batch size, right?

The preprocessing function could split the longer examples into smaller chunks, and each batch could yield any number of processed examples. It looks like the only thing batch_size is used for is pulling chunks of data from the cloud, and nothing downstream will care how many examples are returned, because they’re yielded one at a time. So a batch in MappedExamplesIterable with batch_size=100 could produce 100, or 110, or 3000, or however many examples.
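A minimal sketch of such a batched chunking function, assuming a single `text` column and using characters rather than tokens as the budget (`chunk_batch` and `max_chars` are hypothetical names). Batched `map` functions receive a dict of column lists and return a dict of column lists, and here the returned lists can be longer than the input ones:

```python
from typing import Dict, List


def chunk_batch(batch: Dict[str, List[str]], max_chars: int = 1000) -> Dict[str, List[str]]:
    """Split every document in the batch into whitespace-aligned chunks of at
    most max_chars characters. The output can have more rows than the input.

    Note: str.split() collapses all whitespace (including newlines) to single
    spaces, and a single word longer than max_chars becomes its own chunk."""
    chunks: List[str] = []
    for text in batch["text"]:
        current: List[str] = []
        current_len = 0
        for word in text.split():
            # Account for the joining space when the chunk already has words.
            added = len(word) + (1 if current else 0)
            if current and current_len + added > max_chars:
                chunks.append(" ".join(current))
                current, current_len = [word], len(word)
            else:
                current.append(word)
                current_len += added
        if current:
            chunks.append(" ".join(current))
    return {"text": chunks}
```

With a streaming dataset this would be applied as something like `ds.map(chunk_batch, batched=True, batch_size=100, remove_columns=["id", "text"])`, assuming OSCAR’s columns are `id` and `text`. As I understand the `datasets` docs, removing the input columns avoids a length mismatch between the original columns and the new, longer output, and a returned column with the same name as a removed one is kept.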

The only downside I see is not knowing how many total examples I’ll have to work with. But with a streaming dataset, I have to train with a predefined max_steps anyway, so that doesn’t seem so bad.

Am I understanding this correctly?

Two things I’ve realized already:

  1. The tokenizing step should pad each example to the max_length, since examples from different dataset batches could end up in the same training batch.

  2. Resuming training spends a long time iterating up to the first batch, since the mapped function is called even if the data is just getting skipped, and it’s a non-trivial amount of work. It’s much faster to manually skip batches before the map is applied. But it’s hard to know how far ahead to skip, since there’s no record of which training examples came from which dataset examples. So I split my mapped function in two. The first splits the long data into manageable chunks. I apply that map, manually skip to the checkpoint, and then apply a second map, which has the preprocessing/tokenizing/noising code.
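The two-stage approach in point 2 can be sketched in plain Python with generators. This is a hypothetical illustration (`split_fn`, `preprocess_fn`, and `chunks_to_skip` are assumed names): the cheap split stage runs on everything, the skip happens on its output, and the expensive stage only ever sees chunks that will actually be used.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, List


def resume_pipeline(
    documents: Iterable[str],
    split_fn: Callable[[str], List[str]],
    preprocess_fn: Callable[[str], dict],
    chunks_to_skip: int,
) -> Iterator[dict]:
    # Stage 1: cheap splitting of long documents into chunks.
    chunks = (chunk for doc in documents for chunk in split_fn(doc))
    # Skip the chunks already consumed before the checkpoint, *before* the
    # expensive stage, so preprocess_fn is never called on them.
    remaining = islice(chunks, chunks_to_skip, None)
    # Stage 2: expensive preprocessing/tokenizing/noising on the rest.
    return (preprocess_fn(chunk) for chunk in remaining)
```

In `datasets` terms this corresponds roughly to `ds.map(split_batch, batched=True, ...).skip(n).map(preprocess)`, assuming `IterableDataset.skip` is available in your version. The skip count has to be expressed in stage-1 chunks, which only lines up with training steps if stage 2 maps each chunk to exactly one training example.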

But other than those two things, my initial tests with this seem to be working really well so far.

Yes, exactly! A batched function can return a different number of samples than it receives as input. This can be used to chunk each sample into several samples.

Yes indeed. An alternative could be to do the padding with a collate_fn in PyTorch.
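A minimal sketch of the collate_fn alternative: instead of padding every example to a fixed max_length during tokenization, pad each training batch to its own longest sequence. This is a hypothetical pure-Python version returning lists; the default `pad_token_id=1` matches BART’s tokenizer as far as I know, but it is safer to read it from `tokenizer.pad_token_id`.

```python
from typing import Dict, List


def pad_collate(
    examples: List[Dict[str, List[int]]], pad_token_id: int = 1
) -> Dict[str, List[List[int]]]:
    """Pad input_ids to the longest example in *this* batch and build the
    matching attention mask (1 = real token, 0 = padding)."""
    longest = max(len(ex["input_ids"]) for ex in examples)
    input_ids, attention_mask = [], []
    for ex in examples:
        ids = ex["input_ids"]
        n_pad = longest - len(ids)
        input_ids.append(ids + [pad_token_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```

In a PyTorch `DataLoader` you would wrap the lists in tensors, e.g. `collate_fn=lambda b: {k: torch.tensor(v) for k, v in pad_collate(b).items()}`. Because padding happens per training batch, it no longer matters which dataset batch each example came from. `transformers` also ships ready-made collators such as `DataCollatorWithPadding` that do this for you.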

That makes sense!


Thanks for the confirmation, @lhoestq! This is very cool.

This style of batched fetching is only used by streaming datasets, right? I’d need to roll my own wrapper to do the same on-the-fly chunking on a local dataset loaded from disk?

And I know it’s not your area, but as far as you know, there’s no way to add or change a dataset’s map functions inside the HF Trainer’s train process, is there? Right now I’ve rigged up resuming from a checkpoint by manually advancing the dataset pre-mapping as I described, reducing max_steps by the checkpoint step, and setting ignore_data_skip=True in the trainer. It works, but it’s slightly clunky.
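The bookkeeping for that manual resume is simple arithmetic. A hypothetical sketch, assuming each stage-1 chunk becomes exactly one training example and there is no shuffle buffer: the number of chunks consumed by the checkpoint is the global step times the effective batch size (per-device batch size × gradient accumulation steps × number of processes).

```python
from dataclasses import dataclass


@dataclass
class ResumePlan:
    chunks_to_skip: int   # stage-1 chunks to skip before the expensive map
    remaining_steps: int  # value to pass as max_steps when resuming


def plan_resume(
    global_step: int,
    max_steps: int,
    per_device_batch_size: int,
    grad_accum_steps: int = 1,
    world_size: int = 1,
) -> ResumePlan:
    if not 0 <= global_step <= max_steps:
        raise ValueError("global_step must be between 0 and max_steps")
    # Examples consumed per optimizer step across all processes.
    examples_per_step = per_device_batch_size * grad_accum_steps * world_size
    return ResumePlan(
        chunks_to_skip=global_step * examples_per_step,
        remaining_steps=max_steps - global_step,
    )
```

`chunks_to_skip` would go into the skip between the two maps, `remaining_steps` into `max_steps`, and `ignore_data_skip=True` in `TrainingArguments` stops the Trainer from skipping the same data a second time.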

This style of batched fetching is only used by streaming datasets, right? I’d need to roll my own wrapper to do the same on-the-fly chunking on a local dataset loaded from disk?

Yes indeed, though you can stream the data from your disk as well if you want.
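A minimal roll-your-own version for local data, assuming a plain-text file with one document per line (`iter_chunks` and `split_fn` are hypothetical names). A generator reads lines lazily and yields chunks on the fly, so nothing needs to be indexed or counted up front:

```python
from typing import Callable, Iterable, Iterator, List


def iter_chunks(lines: Iterable[str], split_fn: Callable[[str], List[str]]) -> Iterator[str]:
    """Lazily turn an iterable of lines (e.g. an open file) into chunks."""
    for line in lines:
        doc = line.rstrip("\n")
        if doc:  # skip blank lines
            yield from split_fn(doc)
```

Usage would look like `with open("corpus.txt", encoding="utf-8") as f: for chunk in iter_chunks(f, my_split): ...`. The library route is to load the local files with `streaming=True` (for example `load_dataset("text", data_files="corpus.txt", streaming=True)`) and reuse the same batched `map`; newer `datasets` versions also offer `IterableDataset.from_generator`, but check that your installed version has it.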

A dataset in non-streaming mode needs to have a fixed number of samples known in advance, as well as a mapping between indices and samples. That’s why on-the-fly chunking is not allowed in non-streaming mode.

And I know it’s not your area, but as far as you know, there’s no way to add/change a dataset’s map functions inside the HF Trainer’s train process, is there?

Indeed, there’s no such mechanism AFAIK. I think you would have to subclass the Trainer and implement this logic yourself.

I did not realize this! That should solve the next problem I was going to run into.

Thanks!