I am trying to get a streaming dataset working with `DataLoader` so that it behaves as expected and runs quickly. Reproducible code is shown below; I am using `datasets` version 2.18.0.
My questions are:

- Is explicitly setting `batch_size` in both `.map` and `DataLoader` the expected behaviour? Right now the first `DataLoader` below appears to drop the first 9 elements of each mapped batch, even though the collate function prints the expected batch size of 10.
- I don’t think the work of fetching the 10 elements is parallelised, as it takes about 20 seconds. How would I parallelise this? (See the sketch after the code.)
```python
import io
from typing import Any, Optional, Union

import datasets
import requests
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import AutoTokenizer


class Tokenizer:
    def __init__(self, model_name: str, max_len: int) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_len = max_len

    def __call__(self, x: Union[str, list[str]]) -> dict[str, torch.LongTensor]:
        return self.tokenizer(
            x, max_length=self.max_len, truncation=True, padding=True, return_tensors="pt"
        )


def _get_image(image_url: str) -> Optional[Image.Image]:
    try:
        response = requests.get(image_url, timeout=1)
        response.raise_for_status()  # Raise HTTPError for bad responses (4xx and 5xx)
        return Image.open(io.BytesIO(response.content))
    except (requests.RequestException, IOError):
        return None


class CollateFn:
    def __init__(self, tokenizer: Tokenizer, transform: transforms.Compose):
        self.tokenizer = tokenizer
        self.transform = transform

    def __call__(self, batch: dict[str, Any]) -> dict[str, torch.Tensor]:
        # Download the images for the batch, dropping rows whose download failed
        images: list[Optional[Image.Image]] = [_get_image(url) for url in batch["url"]]
        text_batch: list[str] = [
            text for text, image in zip(batch["short_caption"], images) if image is not None
        ]
        images = [image for image in images if image is not None]
        stacked_images = torch.stack([self.transform(image) for image in images])
        tokenized_text = self.tokenizer(text_batch)
        print(stacked_images.shape, tokenized_text["input_ids"].shape)
        return {
            "image": stacked_images,
            **tokenized_text,
        }


tokenizer = Tokenizer("microsoft/xtremedistil-l6-h256-uncased", 128)
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
)
collate_fn = CollateFn(tokenizer, transform)

dataset = datasets.load_dataset(
    "laion/220k-gpt4vision-captions-from-livis", split="train", streaming=True
)
full_dataset = dataset.shuffle(seed=42, buffer_size=1000).take(100)
train_dataset = full_dataset.take(90)
valid_dataset = full_dataset.skip(90)

# batch_size set only in .map; DataLoader keeps its default batch_size=1
train_dl = DataLoader(
    train_dataset.map(
        collate_fn, batched=True, batch_size=10, remove_columns=["url", "short_caption", "caption"]
    )
)
batch = next(iter(train_dl))
print(batch["image"].shape)  # torch.Size([1, 3, 128, 128]) unexpected

# batch_size set in both .map and DataLoader
train_dl = DataLoader(
    train_dataset.map(
        collate_fn, batched=True, batch_size=10, remove_columns=["url", "short_caption", "caption"]
    ),
    batch_size=10,
    pin_memory=True,
)
batch = next(iter(train_dl))
print(batch["image"].shape)  # torch.Size([10, 3, 128, 128]) as expected
```
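
For the second question, the slow part seems to be the sequential `requests.get` calls inside the collate step. Would something like the following be the right direction? It is an untested sketch that builds on the code above and swaps the download list comprehension for a thread pool (the downloads are I/O-bound, so threads should overlap the waiting); `ThreadedCollateFn` and `max_workers` are my own names, not from any library.

```python
from concurrent.futures import ThreadPoolExecutor


class ThreadedCollateFn(CollateFn):
    """Untested sketch: same collate logic, but image downloads run concurrently."""

    def __init__(
        self, tokenizer: Tokenizer, transform: transforms.Compose, max_workers: int = 10
    ):
        super().__init__(tokenizer, transform)
        self.max_workers = max_workers

    def __call__(self, batch: dict[str, Any]) -> dict[str, torch.Tensor]:
        # Fetch all URLs in the batch concurrently instead of one at a time;
        # requests.get is I/O-bound, so threads should hide most of the latency.
        with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
            images = list(pool.map(_get_image, batch["url"]))
        text_batch = [
            text for text, image in zip(batch["short_caption"], images) if image is not None
        ]
        images = [image for image in images if image is not None]
        stacked_images = torch.stack([self.transform(image) for image in images])
        tokenized_text = self.tokenizer(text_batch)
        return {"image": stacked_images, **tokenized_text}


collate_fn = ThreadedCollateFn(tokenizer, transform, max_workers=10)
```

Or is the intended route to pass `num_workers` to the `DataLoader` so that workers each read from the stream? I was not sure how that interacts with a streaming `IterableDataset`, hence the thread-pool sketch.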