I am getting the following error with the map
function in the laion_400m dataset. I assume, it is because the output of my collate function is 7 (due to a broken image URL), while that of the original batch is 8. I do not care about the other columns in the batch
dictionary. Is there a way to drop them so that I can just focus on these? Here is a full code snippet:
import datasets
CLIP_MODEL = "openai/clip-vit-base-patch32" #"runwayml/stable-diffusion-v1-5"
tokenizer = CLIPTokenizer.from_pretrained(CLIP_MODEL)
feature_extractor = CLIPFeatureExtractor.from_pretrained(CLIP_MODEL)
class CollateFn:
def get_image(self, url):
try:
response = requests.get(url)
return Image.open(io.BytesIO(response.content)).convert("RGB")
except PIL.UnidentifiedImageError:
logger.info(f"Reading error: Could not transform f{url}")
return None
except requests.exceptions.ConnectionError:
logger.info(f"Connection error: Could not transform f{url}")
return None
def __call__(self, batch):
images = [self.get_image(url) for url in batch["url"]]
captions = [caption for caption, image in zip(batch["caption"], images) if image is not None]
images = [image for image in images if image is not None]
tokenized_captions = tokenizer(
captions,
padding="max_length",
truncation=True,
max_length=tokenizer.model_max_length,
return_tensors="pt",
)
image_features = torch.stack([torch.Tensor(feature_extractor(image)["pixel_values"][0]) for image in images])
return {"input_ids": tokenized_captions["input_ids"], "images": image_features}
collate_fn = CollateFn()
laion_ds = datasets.load_dataset("laion/laion400m", split="train", streaming=True)
laion_ds_batched = laion_ds.map(collate_fn, batched=True, batch_size=8)
# error below
tmp = a = next(iter(laion_ds_batched))
Error message:
ValueError Traceback (most recent call last)
Cell In[77], line 1
----> 1 a = next(iter(laion_ds_batched))
File /opt/conda/lib/python3.10/site-packages/datasets/iterable_dataset.py:497, in IterableDataset.__iter__(self)
496 def __iter__(self):
--> 497 for key, example in self._iter():
498 if self.features:
499 # we encode the example for ClassLabel feature types for example
500 encoded_example = self.features.encode_example(example)
File /opt/conda/lib/python3.10/site-packages/datasets/iterable_dataset.py:494, in IterableDataset._iter(self)
492 else:
493 ex_iterable = self._ex_iterable
--> 494 yield from ex_iterable
File /opt/conda/lib/python3.10/site-packages/datasets/iterable_dataset.py:223, in MappedExamplesIterable.__iter__(self)
217 bad_cols = [
218 col
219 for col in transformed_batch
220 if len(transformed_batch[col]) != len(transformed_batch[first_col])
221 ]
222 if bad_cols:
--> 223 raise ValueError(
224 f"Column lengths mismatch: columns {bad_cols} have length {[len(transformed_batch[col]) for col in bad_cols]} while {first_col} has length {len(transformed_batch[first_col])}."
225 )
226 # the new key is the concatenation of the examples keys from the batch
227 new_key = "_".join(str(key) for key in keys)
ValueError: Column lengths mismatch: columns ['input_ids', 'images'] have length [7, 7] while LICENSE has length 8.