How to apply .map() function and keep it as an iterator for a Hugging Face Dataset, in Streaming Mode without loading it to memory?

I’m currently working with the Hugging Face datasets library and need to apply transformations to multiple datasets (such as ds_khan and ds_mathematica) using the .map() function, but in a way that mimics streaming (i.e., without loading the entire dataset into memory). I am particularly interested in interleaving these transformed datasets while keeping the data processing as lazy as possible, similar to streaming=True.

Here is the relevant part of my current code:

import os

from datasets import load_dataset, interleave_datasets

def get_hf_khan_ds(path_2_ds: str, split: str = 'train'):
    path_2_ds = os.path.expanduser(path_2_ds)
    dataset = load_dataset('json', data_files=[path_2_ds], split=split, streaming=True)
    problem_as_text = lambda example: {'text': example['problem']}
    return dataset.map(problem_as_text, remove_columns=dataset.column_names)

def main():
    ds_khan = get_hf_khan_ds('~/gold-ai-olympiad/data/amps/khan/train.jsonl')
    ds_mathematica = get_hf_khan_ds('~/gold-ai-olympiad/data/amps/mathematica/train.jsonl')
    interleaved_datasets = interleave_datasets([ds_khan, ds_mathematica], probabilities=[0.5, 0.5])
    for sample in interleaved_datasets.take(10):
        print(sample)

if __name__ == '__main__':
    main()

This setup is intended to process and interleave the datasets without loading them fully into memory. However, I am not sure if this approach is correctly implementing the streaming and lazy evaluation as I intend.

Questions:

  1. Does this code correctly apply transformations in a streaming or iterator-style fashion?
  2. If not, how can I modify it to ensure that each dataset is only processed as needed, without preloading the entire content?
  3. Is there a more efficient way to interleave these datasets while maintaining a streaming approach?

Any suggestions or insights on how to effectively use .map() with streaming=True for interleaving datasets would be greatly appreciated (note: I do have the dataset on disk, but eventually I want to work with HF datasets).

output:

  table = cls._concat_blocks(blocks, axis=0)
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████| 103059/103059 [00:06<00:00, 16218.25 examples/s]
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████| 103059/103059 [00:07<00:00, 13633.64 examples/s]
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████| 103059/103059 [00:08<00:00, 12444.64 examples/s]
/lfs/ampere1/0/brando9/miniconda/envs/gold_ai_olympiad/lib/python3.11/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
  table = cls._concat_blocks(blocks, axis=0)
Downloading data files: 100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 3792.32it/s]
Extracting data files: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 479.84it/s]
Generating train split: 20 examples [00:00, 3628.29 examples/s]
Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 2792.11 examples/s]
Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 3525.66 examples/s]
Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 3415.97 examples/s]
/lfs/ampere1/0/brando9/miniconda/envs/gold_ai_olympiad/lib/python3.11/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
  table = cls._concat_blocks(blocks, axis=0)
probabilities=[0.3333333333333333, 0.3333333333333333, 0.3333333333333333]
/lfs/ampere1/0/brando9/miniconda/envs/gold_ai_olympiad/lib/python3.11/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
  table = cls._concat_blocks(blocks, axis=0)
Done! Time: 50.09 sec, 0.83 min, 0.01 hr

I didn’t expect the output to show the entire dataset being processed up front, which confused me. It also takes too long to load, which is bad.
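For reference, one quick way to check whether the lazy path is actually being used is to inspect the type that load_dataset returns (a minimal sketch; the local train.jsonl path is just a stand-in for the files above). With streaming=True the result should be a datasets.IterableDataset, and calling .map() on it should return another IterableDataset with no eager "Map: 100%" progress bar:

from datasets import load_dataset, IterableDataset

# 'train.jsonl' is a placeholder for the khan/mathematica files used above
ds = load_dataset('json', data_files=['train.jsonl'], split='train', streaming=True)
print(isinstance(ds, IterableDataset))  # True means the dataset streams lazily

# .map() on an IterableDataset is applied on the fly, example by example
mapped = ds.map(lambda example: {'text': example['problem']})
print(isinstance(mapped, IterableDataset))  # still True: no upfront pass over the data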

ref: machine learning - How to apply .map() function and keep it as an iterator for a Hugging Face Dataset, in Streaming Mode without loading it to memory? - Stack Overflow

@brando I think you’re using the map() and interleave_datasets() functions correctly. At the very least, the code you’ve provided seems to work fine with the Hugging Face datasets library. I tried running a slightly modified version of your code that works with datasets hosted on the Hugging Face Hub instead:

from datasets import load_dataset, interleave_datasets

def get_hf_khan_ds(hf_dataset: str, split: str = 'train'):
    dataset = load_dataset(hf_dataset, split=split, streaming=True)
    column_names = dataset.column_names
    problem_as_text = lambda example: {'text': example['question'], 'source' : hf_dataset}
    return dataset.map(problem_as_text, remove_columns=column_names)

def main():
    ds_1 = get_hf_khan_ds('nvidia/OpenMathInstruct-1')
    ds_2 = get_hf_khan_ds('nvidia/OpenMath-GSM8K-masked')
    interleaved_datasets = interleave_datasets([ds_1, ds_2], probabilities=[0.5, 0.5])
    for i, sample in enumerate(interleaved_datasets.take(10)):
        print(f">>> Sample {i + 1}\ntext : {sample['text']}\nsource : {sample['source']}\n")

if __name__ == '__main__':
    main()

The behavior is as you described. In terms of efficiency, it may be better to use .map() in batched mode:

def get_hf_khan_ds_batched(hf_dataset: str, split: str = 'train'):
    dataset = load_dataset(hf_dataset, split=split, streaming=True)
    # with batched=True, example['question'] is a list of up to batch_size questions
    problem_as_text = lambda example: {
        'text': example['question'],
        'source': [hf_dataset] * len(example['question']),
    }
    return dataset.map(problem_as_text, remove_columns=dataset.column_names, batched=True, batch_size=10)
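As a usage sketch (reusing the same dataset names as above; main_batched is just a hypothetical driver), the batched loader drops into the same interleaving loop. Samples are still yielded one at a time; only the map function now sees chunks of batch_size examples:

def main_batched():
    ds_1 = get_hf_khan_ds_batched('nvidia/OpenMathInstruct-1')
    ds_2 = get_hf_khan_ds_batched('nvidia/OpenMath-GSM8K-masked')
    interleaved = interleave_datasets([ds_1, ds_2], probabilities=[0.5, 0.5])
    for i, sample in enumerate(interleaved.take(10)):
        print(f">>> Sample {i + 1}\ntext : {sample['text']}\nsource : {sample['source']}\n")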