Getting the correct batch size from DataLoader, and speed

I am trying to use a streaming dataset with a DataLoader so that it yields correctly sized batches and runs quickly.

Some reproducible code is shown below and I am using datasets version 2.18.0.

My questions are:

  1. Is explicitly setting batch_size in both .map and the DataLoader the expected behaviour? Right now, with batch_size only in .map, it appears to drop 9 of the first 10 elements, as shown below, even though the collate function prints the expected batch size of 10.
  2. I don’t think fetching these 10 elements is parallelised, since it takes 20 seconds. How would I parallelise this?
import io

from typing import Any, Optional, Union

import datasets
from PIL import Image
import requests
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import AutoTokenizer


class Tokenizer:
    def __init__(self, model_name: str, max_len: int) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_len = max_len

    def __call__(self, x: Union[str, list[str]]) -> dict[str, torch.LongTensor]:
        return self.tokenizer(
            x, max_length=self.max_len, truncation=True, padding=True, return_tensors="pt"
        )


def _get_image(image_url: str) -> Optional[Image.Image]:
    try:
        response = requests.get(image_url, timeout=1)
        response.raise_for_status()  # Raise HTTPError for bad responses (4xx and 5xx)
        image = Image.open(io.BytesIO(response.content))
        return image
    except (requests.RequestException, IOError):
        return None


class CollateFn:
    def __init__(self, tokenizer: Tokenizer, transform: transforms.Compose):
        self.tokenizer = tokenizer
        self.transform = transform

    def __call__(self, batch: dict[str, Any]) -> dict[str, torch.Tensor]:
        images: list[Optional[Image.Image]] = [_get_image(url) for url in batch["url"]]
        text_batch: list[str] = [
            text for text, image in zip(batch["short_caption"], images) if image is not None
        ]
        images = [image for image in images if image is not None]
        stacked_images = torch.stack([self.transform(image) for image in images])
        tokenized_text = self.tokenizer(text_batch)
        print(stacked_images.shape, tokenized_text["input_ids"].shape)
        return {
            "image": stacked_images,
            **tokenized_text,
        }


tokenizer = Tokenizer("microsoft/xtremedistil-l6-h256-uncased", 128)
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
)
collate_fn = CollateFn(tokenizer, transform)

dataset = datasets.load_dataset(
    "laion/220k-gpt4vision-captions-from-livis", split="train", streaming=True
)
full_dataset = dataset.shuffle(seed=42, buffer_size=1000).take(100)
train_dataset = full_dataset.take(90)
valid_dataset = full_dataset.skip(90)


train_dl = DataLoader(
    train_dataset.map(
        collate_fn, batched=True, batch_size=10, remove_columns=["url", "short_caption", "caption"]
    )
)
batch = next(iter(train_dl))
print(batch["image"].shape)  # torch.Size([1, 3, 128, 128]) unexpected

train_dl = DataLoader(
    train_dataset.map(
        collate_fn, batched=True, batch_size=10, remove_columns=["url", "short_caption", "caption"]
    ),
    batch_size=10,
    pin_memory=True,
)
batch = next(iter(train_dl))
print(batch["image"].shape)  # torch.Size([10, 3, 128, 128]) as expected

Hi! batch_size in map() only controls the batch size inside the map operation - it doesn’t batch the output of the dataset once you iterate over it. If you want your data loader to yield batches, you should pass batch_size to the data loader:

train_dl = DataLoader(
    train_dataset.map(
        collate_fn, batched=True, batch_size=10, remove_columns=["url", "short_caption", "caption"]
    ),
    batch_size=10
)
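
One way to avoid setting batch_size twice (just a sketch, not something from the library docs) is to skip the batched map() and hand the collating to the DataLoader directly. The DataLoader passes its collate_fn a list of per-example dicts rather than the dict of lists that CollateFn expects, so the hypothetical DataLoaderCollate adapter below converts between the two, reusing the definitions from the first post:

class DataLoaderCollate:
    # hypothetical adapter: the DataLoader hands collate_fn a list of example
    # dicts, while CollateFn above expects a dict of lists
    def __init__(self, collate_fn: CollateFn):
        self.collate_fn = collate_fn

    def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        batch = {key: [example[key] for example in examples] for key in examples[0]}
        return self.collate_fn(batch)


train_dl = DataLoader(
    train_dataset,  # the raw streaming dataset, no .map()
    batch_size=10,
    collate_fn=DataLoaderCollate(collate_fn),
)
batch = next(iter(train_dl))
print(batch["image"].shape)  # torch.Size([10, 3, 128, 128]) if all 10 downloads succeed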

Thanks for this. Correct me if I’m wrong: since map() batches up 10 elements and the data loader then takes the first of those elements, does that mean that when I set batch_size=10 in the DataLoader, 100 elements will have been processed to get the final 10?

I’m not a Hugging Face expert, but I think .map() parallelises the preprocessing, while the DataLoader controls the number of examples that are loaded and returned as a batch from the dataset during iteration.

To speed things up: you are fetching the images from their URLs one by one in a loop, so I/O becomes the bottleneck. A simple fix is to download them concurrently. For example,

from concurrent import futures

class CollateFnConcurrent:
    def __init__(self, tokenizer: Tokenizer, transform: transforms.Compose):
        self.tokenizer = tokenizer
        self.transform = transform

    def __call__(self, batch: dict[str, Any]) -> dict[str, torch.Tensor]:
        # executor.map preserves the input order (unlike as_completed), so each
        # downloaded image stays aligned with its caption in the zip below
        with futures.ThreadPoolExecutor() as executor:
            images: list[Optional[Image.Image]] = list(executor.map(_get_image, batch["url"]))
        text_batch: list[str] = [
            text
            for text, image in zip(batch["short_caption"], images)
            if image is not None
        ]
        images = [image for image in images if image is not None]
        stacked_images = torch.stack([self.transform(image) for image in images])
        tokenized_text = self.tokenizer(text_batch)
        print(stacked_images.shape, tokenized_text["input_ids"].shape)
        return {
            "image": stacked_images,
            **tokenized_text,
        }

collate_fn = CollateFnConcurrent(tokenizer, transform)
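
A rough way to check the effect, reusing the pipeline from the first post (the time_first_batch helper below is purely illustrative, and the actual numbers will depend on the network and on how many of the 10 downloads succeed):

import time

def time_first_batch(fn) -> float:
    # hypothetical helper: build the same pipeline as above with the given
    # collate function and time how long the first batch of 10 takes
    dl = DataLoader(
        train_dataset.map(
            fn, batched=True, batch_size=10,
            remove_columns=["url", "short_caption", "caption"],
        ),
        batch_size=10,
    )
    start = time.perf_counter()
    next(iter(dl))
    return time.perf_counter() - start

print("sequential:", time_first_batch(CollateFn(tokenizer, transform)))
print("concurrent:", time_first_batch(CollateFnConcurrent(tokenizer, transform)))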

No, even if you use a batched map(), examples will still be yielded one by one to the DataLoader, which will group them into batches.

So when the DataLoader asks for the first element, the map() function is applied to the first 10 elements (batched map()), and the first processed element is passed to the DataLoader.
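
A small toy sketch (using a dummy in-memory dataset converted to an iterable one, not the LAION data) makes this visible: the batched map() function runs once per 10 examples, but the mapped dataset still yields single examples, which the DataLoader regroups into whatever batch size it was given:

import datasets
from torch.utils.data import DataLoader

def double(batch):
    # runs once per map() batch; batch["x"] is a list of up to 10 ints
    print(f"map() received {len(batch['x'])} examples")
    return {"x": [v * 2 for v in batch["x"]]}

toy = datasets.Dataset.from_dict({"x": list(range(30))}).to_iterable_dataset()
mapped = toy.map(double, batched=True, batch_size=10)

dl = DataLoader(mapped, batch_size=4)
first = next(iter(dl))
# "map() received 10 examples" is printed once: the first 10 examples are
# processed together, then yielded one by one and regrouped into a batch of 4
print(first["x"])  # tensor([0, 2, 4, 6])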