Length error using `map` with datasets

I am getting the following error when using `map` on the `laion/laion400m` dataset. I assume it is because my collate function returns 7 rows (one image URL is broken), while the original batch has 8. I do not care about the other columns in the batch dictionary. Is there a way to drop them so that I can focus only on these? Here is a full code snippet:

```python
import io
import logging

import datasets
import PIL
import requests
import torch
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer

logger = logging.getLogger(__name__)

CLIP_MODEL = "openai/clip-vit-base-patch32"  # "runwayml/stable-diffusion-v1-5"

tokenizer = CLIPTokenizer.from_pretrained(CLIP_MODEL)
feature_extractor = CLIPFeatureExtractor.from_pretrained(CLIP_MODEL)


class CollateFn:
    def get_image(self, url):
        """Download an image; return None if it cannot be fetched or decoded."""
        try:
            response = requests.get(url)
            return Image.open(io.BytesIO(response.content)).convert("RGB")
        except PIL.UnidentifiedImageError:
            logger.info(f"Reading error: Could not transform {url}")
            return None
        except requests.exceptions.ConnectionError:
            logger.info(f"Connection error: Could not transform {url}")
            return None

    def __call__(self, batch):
        images = [self.get_image(url) for url in batch["url"]]
        # Keep only captions whose image loaded, so captions and images stay aligned.
        captions = [caption for caption, image in zip(batch["caption"], images) if image is not None]
        images = [image for image in images if image is not None]

        tokenized_captions = tokenizer(
            captions,
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        )

        image_features = torch.stack([torch.Tensor(feature_extractor(image)["pixel_values"][0]) for image in images])

        # Both outputs have one row per successfully loaded image (7 here, not 8).
        return {"input_ids": tokenized_captions["input_ids"], "images": image_features}


collate_fn = CollateFn()
laion_ds = datasets.load_dataset("laion/laion400m", split="train", streaming=True)
laion_ds_batched = laion_ds.map(collate_fn, batched=True, batch_size=8)

# error below
a = next(iter(laion_ds_batched))
```

Error message:

```
ValueError                                Traceback (most recent call last)
Cell In[77], line 1
----> 1 a = next(iter(laion_ds_batched))

File /opt/conda/lib/python3.10/site-packages/datasets/iterable_dataset.py:497, in IterableDataset.__iter__(self)
    496 def __iter__(self):
--> 497     for key, example in self._iter():
    498         if self.features:
    499             # we encode the example for ClassLabel feature types for example
    500             encoded_example = self.features.encode_example(example)

File /opt/conda/lib/python3.10/site-packages/datasets/iterable_dataset.py:494, in IterableDataset._iter(self)
    492 else:
    493     ex_iterable = self._ex_iterable
--> 494 yield from ex_iterable

File /opt/conda/lib/python3.10/site-packages/datasets/iterable_dataset.py:223, in MappedExamplesIterable.__iter__(self)
    217     bad_cols = [
    218         col
    219         for col in transformed_batch
    220         if len(transformed_batch[col]) != len(transformed_batch[first_col])
    221     ]
    222     if bad_cols:
--> 223         raise ValueError(
    224             f"Column lengths mismatch: columns {bad_cols} have length {[len(transformed_batch[col]) for col in bad_cols]} while {first_col} has length {len(transformed_batch[first_col])}."
    225         )
    226 # the new key is the concatenation of the examples keys from the batch
    227 new_key = "_".join(str(key) for key in keys)

ValueError: Column lengths mismatch: columns ['input_ids', 'images'] have length [7, 7] while LICENSE has length 8.
```
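The failing check can be reproduced without downloading anything. The plain-Python sketch below is modelled on the lines shown in the traceback. It assumes, without confirmation from this thread, that a batched `map` starts from the input batch's columns, overlays the columns the function returns, deletes anything listed in `remove_columns`, and only then compares lengths. `apply_batched_fn`, `check_column_lengths` and the toy `drop_broken` function are hypothetical stand-ins, not `datasets` APIs.

```python
def check_column_lengths(transformed):
    """Raise ValueError if any column's length differs from the first column's."""
    if not transformed:
        return
    first_col = next(iter(transformed))
    bad_cols = [
        col for col in transformed
        if len(transformed[col]) != len(transformed[first_col])
    ]
    if bad_cols:
        raise ValueError(
            f"Column lengths mismatch: columns {bad_cols} have length "
            f"{[len(transformed[col]) for col in bad_cols]} while {first_col} "
            f"has length {len(transformed[first_col])}."
        )


def apply_batched_fn(batch, fn, remove_columns=()):
    """Hypothetical model of one batched-map step (assumed behaviour, not the datasets API)."""
    # Start from the input columns, then overlay the function's output columns.
    transformed = dict(batch)
    transformed.update(fn(batch))
    # Columns listed in remove_columns are dropped before the length check.
    for col in remove_columns:
        del transformed[col]
    check_column_lengths(transformed)
    return transformed


def drop_broken(batch):
    """Toy stand-in for CollateFn: keeps only rows whose URL is not 'broken'."""
    keep = [url != "broken" for url in batch["url"]]
    captions = [c for c, k in zip(batch["caption"], keep) if k]
    # Toy "token ids" and "pixels", one entry per kept row.
    return {"input_ids": [len(c) for c in captions], "images": [f"pixels for {c}" for c in captions]}


batch = {
    "LICENSE": ["?"] * 8,
    "url": [f"u{i}" for i in range(7)] + ["broken"],
    "caption": [f"caption {i}" for i in range(8)],
}

# Keeping the original 8-row columns next to the 7-row outputs fails.
try:
    apply_batched_fn(batch, drop_broken)
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)

# Removing every original column leaves only the 7-row outputs.
out = apply_batched_fn(batch, drop_broken, remove_columns=list(batch))
print(sorted(out), len(out["input_ids"]))
```

With the original columns removed, only the 7-row outputs remain and the check passes, which is the idea behind the `remove_columns` suggestion in the reply below.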

Hi! Yes, you can remove the other columns with:

```python
laion_ds_batched = laion_ds.map(collate_fn, batched=True, batch_size=8, remove_columns=laion_ds.column_names)
```
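If some of the original columns were still wanted, another option is for the batched function to filter every column it returns with the same keep-mask, so nothing is left at the old length. This is a sketch of that alternative, not part of the reply above: `filter_batch` is a hypothetical helper, and the toy batch stands in for a real LAION batch.

```python
def filter_batch(batch, keep):
    """Return a copy of `batch` with every column filtered by the same boolean mask."""
    # Every column must line up with the mask, otherwise rows would be misaligned.
    if any(len(values) != len(keep) for values in batch.values()):
        raise ValueError("every column must have one entry per mask element")
    return {col: [v for v, k in zip(values, keep) if k] for col, values in batch.items()}


batch = {
    "LICENSE": ["?"] * 8,
    "url": [f"https://example.com/{i}.jpg" for i in range(8)],
    "caption": [f"caption {i}" for i in range(8)],
}
# Pretend the image at index 3 could not be downloaded.
keep = [i != 3 for i in range(8)]
filtered = filter_batch(batch, keep)
print({col: len(values) for col, values in filtered.items()})
# {'LICENSE': 7, 'url': 7, 'caption': 7}
```

Inside `CollateFn.__call__`, one might build `keep = [image is not None for image in images]` and return `{**filter_batch(batch, keep), "input_ids": ids, "images": pixels}` (with `ids` and `pixels` as hypothetical names for the two tensors). This assumes that columns returned by the function replace the input columns of the same name, which is worth checking against the installed `datasets` version.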

Hey, almost there. Since `laion_ds` is an `IterableDataset`, I could only do the following:

```python
laion_ds_batched = laion_ds.map(collate_fn, batched=True, batch_size=8, remove_columns=next(iter(laion_ds)).keys())
```

However, when I run `x = next(iter(laion_ds_batched))`, the returned dictionary contains only one element out of the possible 8. Stranger still, when I set a breakpoint inside the `CollateFn` class, it batches more than one element properly.
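One possible explanation for the single element, offered as an assumption rather than a confirmed answer: `batched=True` and `batch_size=8` control how many examples the function receives at once (which matches what the breakpoint shows), but the mapped `IterableDataset` may still yield its output one example at a time. The plain-Python sketch below models that behaviour and shows one way to regroup rows afterwards; `mapped_examples`, `rebatch` and `double` are hypothetical names, not `datasets` APIs.

```python
from itertools import islice


def mapped_examples(rows, fn, batch_size):
    """Hypothetical model of a batched map over a stream of dict rows.

    Rows are grouped into batches of `batch_size`, converted to a dict of
    columns, passed to `fn`, and the result is yielded back one row at a time.
    """
    it = iter(rows)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        batch = {key: [row[key] for row in chunk] for key in chunk[0]}
        out = fn(batch)
        n_rows = len(next(iter(out.values()))) if out else 0
        for i in range(n_rows):
            yield {key: values[i] for key, values in out.items()}


def rebatch(rows, batch_size):
    """Group a stream of dict rows back into dict-of-lists batches."""
    it = iter(rows)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        yield {key: [row[key] for row in chunk] for key in chunk[0]}


def double(batch):
    # Toy batched function: a breakpoint here would see all 8 values at once.
    return {"y": [2 * x for x in batch["x"]]}


stream = [{"x": i} for i in range(8)]

first = next(iter(mapped_examples(stream, double, batch_size=8)))
print(first)  # {'y': 0}

first_batch = next(rebatch(mapped_examples(stream, double, batch_size=8), batch_size=8))
print(first_batch)  # {'y': [0, 2, 4, 6, 8, 10, 12, 14]}
```

In a training loop, this regrouping is usually done by a PyTorch `DataLoader` with its own `batch_size` rather than by a hand-written helper.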

A minimal working example is in the Kaggle notebook laion_hf_dataset. Is this possibly a bug?