I'm not a Hugging Face expert, but I believe .map() parallelizes the preprocessing, while the DataLoader controls how many examples are loaded from the dataset and returned as a batch during iteration.
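In other words, they are two separate knobs. A minimal sketch of where each one lives (the dataset name and preprocessing function below are just placeholders, not from your code):

    from datasets import load_dataset
    from torch.utils.data import DataLoader

    dataset = load_dataset("user/some-dataset", split="train")  # placeholder name

    def add_caption_length(example):
        # toy preprocessing step, applied once over the whole dataset
        return {"caption_length": len(example["short_caption"])}

    # num_proc parallelizes this one-off preprocessing pass
    dataset = dataset.map(add_caption_length, num_proc=4)

    # batch_size / num_workers only control how examples are grouped and loaded per step
    loader = DataLoader(dataset, batch_size=32, num_workers=2)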
To speed up your pipeline: you are reading each image from its URL in a plain for loop, so I/O becomes the bottleneck. A simple fix is to fetch the images concurrently with a thread pool inside the collate function. For example,
from concurrent import futures
from typing import Any

import torch

# Tokenizer, transforms, and _get_image are assumed to be defined as in your original code.
class CollateFnConcurrent:
    def __init__(self, tokenizer: Tokenizer, transform: transforms.Compose):
        self.tokenizer = tokenizer
        self.transform = transform

    def __call__(self, batch: dict[str, Any]) -> dict[str, torch.Tensor]:
        # Fetch all images of the batch concurrently. executor.map preserves the
        # input order, so the images stay aligned with their captions.
        with futures.ThreadPoolExecutor() as executor:
            images = list(executor.map(_get_image, batch["url"]))

        # Drop caption/image pairs whose download failed (_get_image returned None).
        text_batch: list[str] = [
            text
            for text, image in zip(batch["short_caption"], images)
            if image is not None
        ]
        images = [image for image in images if image is not None]

        stacked_images = torch.stack([self.transform(image) for image in images])
        tokenized_text = self.tokenizer(text_batch)
        print(stacked_images.shape, tokenized_text["input_ids"].shape)
        return {
            "image": stacked_images,
            **tokenized_text,
        }

collate_fn = CollateFnConcurrent(tokenizer, transform)
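For completeness, here is one way the pieces could be wired together. The _get_image helper below is only a guess at what yours looks like (requests plus PIL, returning None on failure so the filtering above still works), and the dataset and batch-size values are placeholders:

    import io

    import requests
    from PIL import Image
    from torch.utils.data import DataLoader

    def _get_image(url: str) -> Image.Image | None:
        # Hypothetical downloader: return None on any failure so the collate
        # function can drop the corresponding caption.
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert("RGB")
        except Exception:
            return None

    loader = DataLoader(
        dataset,                # your Hugging Face dataset
        batch_size=64,          # placeholder value
        collate_fn=collate_fn,
        num_workers=4,          # each worker runs its own thread pool
    )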