Streaming batched data

Hello, I’m trying to batch a streaming dataset. This is what I have done so far:

from datasets import load_dataset

coco_train = load_dataset(
    "facebook/pmd",
    use_auth_token=hf_token,
    name="coco",
    cache_dir="/kaggle/working/",
    streaming=True,
    split="train",
)
train_ds = coco_train.map(batched=True, batch_size=32)

however, doing next(iter(train_ds)) only returns one instance (and it is quite slow). Is there something I’m missing here? How do you get a batch of data?

Hi! map with batched=True processes examples by batch, and the output of the map function can be a batch of arbitrary size. Therefore you can transform a batch of batch_size examples into a batch of 1 example:

def group_batch(batch):
    # Wrap each column's list of values in an outer list,
    # turning batch_size examples into a single example.
    return {k: [v] for k, v in batch.items()}

train_ds = coco_train.map(group_batch, batched=True, batch_size=32)
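To see what this does to a single batch, here is a minimal sketch using a toy dict in place of a real batch (no datasets library needed; the column names are made up for illustration). With batched=True, map hands the function a dict mapping each column name to a list of batch_size values:

```python
def group_batch(batch):
    # Wrap each column's list in another list, so batch_size
    # examples become one "example" whose fields are lists.
    return {k: [v] for k, v in batch.items()}

# A fake batch of 3 examples with two hypothetical columns.
batch = {"image_id": [1, 2, 3], "caption": ["a", "b", "c"]}
grouped = group_batch(batch)
print(grouped)  # {'image_id': [[1, 2, 3]], 'caption': [['a', 'b', 'c']]}
```

After this transform, next(iter(train_ds)) yields one item whose fields each hold a whole batch of values.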

That worked, thank you :pray:. Although I’m not sure how you came up with that group_batch function. Are there any docs I could have referred to in order to figure that out for myself?

Also, it seems to be significantly faster to simply download the data and then iterate rather than use the streaming method. Is that generally the case? I’m running the following and it’s currently at 6.83 s/it (for a batch size of 64 on Kaggle).

from tqdm import tqdm

for data in tqdm(train_ds):
    continue

> Although I’m not sure how you got that group_batch function. Is there any docs I could have referred to have figured that out for myself.

With map you can change the size of a batch, so you can group a batch of 32 elements into 1. You can find more examples in the docs: Process
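The key point is that a batched map function may return a different number of rows than it receives. As a toy sketch of that contract (plain Python, not the real datasets pipeline; the rechunk name and chunk size are made up), here is a function that splits each incoming batch into smaller chunks instead of merging it:

```python
def rechunk(batch, chunk_size=2):
    # A batched map function receives a dict of column -> list of values
    # and may return lists of a *different* length (here: one row per chunk).
    n = len(next(iter(batch.values())))
    return {
        k: [v[i:i + chunk_size] for i in range(0, n, chunk_size)]
        for k, v in batch.items()
    }

batch = {"x": [1, 2, 3, 4, 5]}
print(rechunk(batch))  # {'x': [[1, 2], [3, 4], [5]]}
```

group_batch is just the special case where the whole incoming batch becomes a single row.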

> Also it seems to be significantly faster to simply download the data and then iterate rather than use the streaming method. Is that generally the case?

It depends on where the data is hosted; some hosts have limited download speed.
