Limitations of iterable datasets

Hi everyone,

I have started to set up my research project based on RoBERTa and your run_mlm.py example with the Trainer. For that purpose I only worked on a subset of my dataset, which I load in memory, and benchmarked the parallel-processing speed. I am satisfied with the results and will move on to the next steps.

For context, I launch my scripts as
OMP_NUM_THREADS=12 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nnodes=1 --nproc_per_node=8 run_mlm.py --dataloader_num_workers 64 --sharded_ddp zero_dp_2 …

I want to work with streaming datasets and I wonder about their limitations, and whether I should instead default to loading everything in memory. Here are my questions, thanks for your advice.

_ the map function of iterable datasets doesn't seem to accept the num_proc argument; I wonder whether this will create a bottleneck in my code, or whether dataloader_num_workers will allow the iterable dataset to operate with fast multi-processing?

_ when working in run_mlm.py with the Trainer and an iterable dataset, what changes should I make for parallel processing, please?
I read this Process but I am not sure whether it applies

_ my datasets are stored as .parquet files containing input sequences as well as labels/meta-data; one column I would like to add is a sampling probability, in order to over-sample certain training examples
Is there any way to support this inside an iterable dataset, or should I consider duplicating training examples as a pre-processing step?

Thanks !
A

Hi!

_ the map function of iterable datasets doesn't seem to accept the num_proc argument; I wonder whether this will create a bottleneck in my code, or whether dataloader_num_workers will allow the iterable dataset to operate with fast multi-processing?

Adding support for multiple workers (num_workers > 1) to IterableDataset is a work in progress and will (most likely) be available in the next release of datasets. But in your case, for maximum performance, it's better to use the standard Arrow-backed Dataset. Thanks to memory mapping, this version also doesn't load everything into memory (only the requested rows/columns).

You can create a dataset from parquet files (the arrow backed version) as follows:

from datasets import load_dataset
dataset = load_dataset("parquet", data_files=[<list of paths to parquet files>])
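
Accessing rows then only reads the requested data from disk thanks to memory mapping; for example (assuming the default "train" split name):

print(dataset["train"][0])  # loads just this row, not the whole dataset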

_ when working in run_mlm.py with the Trainer and an iterable dataset, what changes should I make for parallel processing, please?
I read this Process but I am not sure whether it applies

You can use the training_args.main_process_first context manager for that (for the Arrow-backed dataset). You can find an example here.
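
For reference, in run_mlm.py the tokenization map call is wrapped roughly like this (a sketch, where raw_datasets, tokenize_function, data_args and column_names are the variables used in the example script):

with training_args.main_process_first(desc="dataset map tokenization"):
    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
    )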

_ my datasets are stored as .parquet files containing input sequences as well as labels/meta-data; one column I would like to add is a sampling probability, in order to over-sample certain training examples
Is there any way to support this inside an iterable dataset, or should I consider duplicating training examples as a pre-processing step?

I’m not sure I understand this question. Could you clarify it a bit more?

Hi Mario and thanks for your reply

I think I am getting set with the first two points; I did not observe the code slowing down much whether I pass iterable datasets or datasets with streaming=False.

About the 3rd point, I think I will go with the option of replicating examples as a pre-processing step, which is the easiest. But to clarify, my question was about handling the case where I have a dataset, e.g. (x1, x2, …, xN), and I would like to train without seeing each x exactly once per epoch. Imagine some samples are e.g. harder than others, or belong to more or less represented clusters; I could then over-sample these by providing (p1, p2, …, pN) and sampling mini-batches according to a probability that is e.g. increased for harder or less represented examples, as in the sketch below.
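
In plain PyTorch terms (outside the Trainer and without streaming), what I had in mind could be expressed with a WeightedRandomSampler, roughly like this toy sketch with made-up weights:

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# toy map-style dataset with 4 examples
train_dataset = TensorDataset(torch.arange(4))

# per-example sampling probabilities (p1, ..., pN); larger values are drawn more often
weights = torch.tensor([0.1, 0.1, 0.3, 0.5])

# draw len(train_dataset) examples per epoch, with replacement, according to the weights
sampler = WeightedRandomSampler(weights, num_samples=len(train_dataset), replacement=True)
loader = DataLoader(train_dataset, batch_size=2, sampler=sampler)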

Still, right now I am having some issues working my way towards equivalent results with and without streaming datasets.

For others who may see this thread: I had issues running the HF Trainer with iterable datasets because at first I had not noticed that HF iterable datasets (returned by load_dataset(…, streaming=True)) are not directly usable by PyTorch, and that I need to call dataset = dataset.with_format("torch") after applying map and before passing the dataset to the Trainer.
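
Concretely, the pattern that works for me looks roughly like this (the parquet path and the "text" column name are placeholders for my actual data):

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=512)

# streaming=True returns an iterable dataset that reads the parquet shards lazily
train_stream = load_dataset("parquet", data_files=["<path/to/train.parquet>"], streaming=True)["train"]
train_stream = train_stream.map(tokenize_function, batched=True)
# required so the Trainer / PyTorch DataLoader receive torch-compatible examples
train_stream = train_stream.with_format("torch")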

About the points I am currently having issues with, in case you have some hints for me:

_ training loss curves decrease smoothly with streaming=False, but with iterable datasets the losses currently do not converge smoothly and even tend to diverge … I am still debugging and have not identified all possible causes. As far as I can tell, the only difference between streaming=False and streaming=True is that with streaming I cannot use the group_by_length training option … apart from that I did not notice any other differences … am I missing something specific that I need to handle manually, e.g. shuffling, when using iterable datasets with the HF Trainer?

_ to evaluate the model, either during training with e.g. evaluation_strategy="epoch" or at the end of training with e.g. metrics = trainer.evaluate(); I read that there are issues because the length of the evaluation/test datasets should be known in advance … is there a standard way to perform evaluation on iterable datasets, such as callbacks, or should I e.g. use a streaming dataset for training and keep the eval/test splits in memory?

Our servers have rather large RAM (1.5 TB), so I could actually load my datasets in memory,
but I observed that parallel runs on very large datasets (e.g. 500M training examples) take a lot of time to initialise training, i.e. when calling trainer.train.
This is actually longer than the time needed for dataset preprocessing, e.g. tokenization.
Again, this is to be taken with a pinch of salt, and maybe debugging work on my side will solve it; any hint would be very appreciated!

best,
A

Hi @mariosasko

I actually have an idea why the loss behaves differently in streaming and non-streaming mode; it would be great if you could confirm it.
When I am training with streaming (i.e. an iterable dataset), the logger only sees one epoch, which is the chosen number of training steps.
So I am afraid there is no reshuffling of the dataset during training … am I right?

The question here is: what is the best way to fix this, please?
Is there a place where I should configure the length of the dataset, which is known in advance in my case?
Or should I add a callback every (dataset length / batch size) steps to manually shuffle the dataset?

Hi @adrienchaton ,

I have noticed something similar as well, with spiky convergence when training with streamed data and an IterableDataset versus a non-streamed, non-iterable local dataset.

It may be worth checking whether using ShufflerIterDataPipe() to shuffle the batches in the iterable data loader helps to resolve your issue.

For example something like this:

from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe

# wrap the iterable dataset in a buffered-shuffle datapipe
shuffled_batches = ShufflerIterDataPipe(your_torch_dataset)

train_dataloader = DataLoader(shuffled_batches, shuffle=True, batch_size=8)

I have been working through it with the Hugging Face team and documenting my results in this thread: Streaming Dataset of Sequence Length 2048 - #7 by loubnabnl

Hope this will help.

Best.

Hi @conceptofmind

Thanks for pointing out your experiments to me and some tools which could help me out.
In the meantime I wrote custom datasets and data collators for HF/PyTorch that use memory-mapped Arrow tables and tokenize on the fly, roughly along the lines of the sketch below. This has fixed most of my issues, i.e. good convergence and moderate RAM use.
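
A minimal sketch of what I mean (names and the "text" column are illustrative, not my exact code):

import torch

class OnTheFlyTokenizingDataset(torch.utils.data.Dataset):
    # map-style dataset over a memory-mapped Arrow-backed datasets.Dataset
    def __init__(self, hf_dataset, tokenizer, max_length=512):
        self.hf_dataset = hf_dataset  # rows are read from disk only when accessed
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        text = self.hf_dataset[idx]["text"]
        return self.tokenizer(text, truncation=True, max_length=self.max_length)

Combined with e.g. DataCollatorForLanguageModeling as the collator, padding and MLM masking then happen per batch.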

I will try using HF's streaming datasets with ShufflerIterDataPipe and see if it behaves well while reducing RAM use even further!

Best.

@mariosasko What's the status of datasets implementing multi-process loading for IterableDataset? I'm happy to help out if needed; this would really speed up my model training and allow me to stay within the HF ecosystem.

Hi ! We haven’t had a chance to dive into this unfortunately, but contributions are open if you’d like to help :wink: We’d be happy to give you some pointers

Sure! Let me know where’s a good place to start, this would be very helpful for me at least!

Feel free to open an issue on github and explain what kind of multiprocessing would make sense for you when streaming and we can continue from there :slight_smile:

Hi @mariosasko, can you please describe the recommended dataset size limit beyond which an Arrow-backed dataset is no longer effective in terms of memory use and it becomes better to use an IterableDataset? As you mentioned, the Arrow-backed Dataset is quite efficient thanks to its design, so I'm wondering where the threshold is at which an IterableDataset starts to make sense. Thank you very much!