Streaming and creating refactored dataset with shards using Generator

I am trying to stream a dataset (i.e. to disk, not to memory), refactor it using a generator and map, and then push it back to the Hub. The following methodology achieves this, but it is slow because of the following warning: Setting num_proc from 16 back to 1 for the train split to disable multiprocessing as it only contains one shard.

N.B. there is a related GitHub issue here, but I could not work out a solution using gen_kwargs.

Here is my minimal reproducible code:

from datasets import Dataset, load_dataset
import os

# Load the dataset in streaming mode
streaming_ds = load_dataset("Tevatron/docmatix-ir", split="train", streaming=True)

# Create the map function to modify the dataset
def map_function(example):
    return {
        "query": example["query"],
        "gold_index": example.pop("positive_passages"),
        "negs": example.pop("negative_passages")
    }

# Apply the transformation in a generator function
def generator():
    # Stream every example and apply the transformation
    for example in streaming_ds:
        yield map_function(example)

# Use from_generator to process and save to disk
ds_on_disk = Dataset.from_generator(generator, num_proc=os.cpu_count())

# Push to Hugging Face Hub
ds_on_disk.push_to_hub("mixedbread-ai/omni-retrieval-dataset-william", "docmatix-ir-processed", private=True)

The from_generator documentation gives the following example of sharded generation:

def gen(shards):
    for shard in shards:
        with open(shard) as f:
            for line in f:
                yield {"line": line}
shards = [f"data{i}.txt" for i in range(32)]
ds = Dataset.from_generator(gen, gen_kwargs={"shards": shards})
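
As I understand it, combining num_proc with a list passed via gen_kwargs should let each worker process handle its own subset of shards and write its own output shard. A rough sketch of what I mean (the num_proc value is just an example):

# assumption: datasets splits the `shards` list across the worker processes,
# so each process runs gen() on its own subset and writes its own Arrow shard
ds = Dataset.from_generator(gen, gen_kwargs={"shards": shards}, num_proc=4)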

Therefore I experimented with a new script that uses gen_kwargs to take in a series of shards from another dataset.

from datasets import Dataset, load_dataset
import os

def load_shard_dataset(shard_num):
    base_url = "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_{i:06d}.tar"
    # Generate the URL for the specified shard
    url = base_url.format(i=shard_num)
    # Load the specific shard as the dataset
    dataset = load_dataset("webdataset", data_files={"train": [url]}, split="train", streaming=True)
    return dataset

# Create the map function to modify the dataset
def map_function(example):
    return {
        "query": example["json"]["prompt"],
    }

# Apply the transformation in a generator function
def generator(shards):
    for shard in shards:
        ds = load_shard_dataset(shard)
        for example in ds:
            yield map_function(example)

# Use from_generator to process
shards = list(range(30))
print("Generating dataset")
ds_on_disk = Dataset.from_generator(generator, gen_kwargs={"shards": shards}, num_proc=os.cpu_count())

print("Pushing to Hugging Face Hub")
ds_on_disk.push_to_hub("mixedbread-ai/omni-retrieval-dataset-william", "docmatix-ir-processed", private=True)

This removes the warning (so I assume num_proc is actually set to 16), however it is even slower than using a single CPU.


Hi! You can also use a DataLoader for this if it’s easier for you:

from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader

ds = load_dataset("jackyhate/text-to-image-2M", streaming=True, split="train")
ds = ds.map(map_function)
# the 16 DataLoader workers iterate over the streamed shards in parallel
dataloader = DataLoader(ds, num_workers=16)
# the main process consumes the iterator and writes the examples to disk
ds = Dataset.from_generator(dataloader.__iter__)

(the difference is that the data is written by the main process, while from_generator with num_proc has one writer per subprocess)


Thank you very much for your reply! Using a DataLoader is a neat solution and utilizes several cores, as long as the dataset has more than one shard.

A note for others: I could not use the images directly with the DataLoader, as the PIL images are not tensors. So I created a custom collate function for the DataLoader using the solution from here.
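
For illustration, here is a rough sketch of the kind of collate function I mean (the "query" and "jpg" keys are assumptions based on my map function and the webdataset layout; the actual keys depend on how the dataset is loaded):

from torch.utils.data import DataLoader

# keep PIL images in a plain list instead of letting the default collate
# try to stack them into tensors (which fails for PIL images)
def collate_fn(batch):
    return {
        "query": [example["query"] for example in batch],  # assumed text key
        "image": [example["jpg"] for example in batch],    # assumed image key
    }

dataloader = DataLoader(ds, num_workers=16, batch_size=8, collate_fn=collate_fn)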


Great! By the way, if you want the dataset to yield examples as torch tensors (including the PIL images), you can set the format to “torch”:

ds = ds.with_format("torch")
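
For example, combined with the DataLoader approach above (a rough sketch, assuming map_function produces the keys you need when the repo is loaded directly):

ds = load_dataset("jackyhate/text-to-image-2M", streaming=True, split="train")
ds = ds.map(map_function).with_format("torch")
# with the "torch" format, examples (including decoded images) come out as torch tensors
dataloader = DataLoader(ds, num_workers=16)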
