The map method’s fn_kwargs is not working with streaming=True, but it works otherwise. Here is the my implementation
USER_AGENT = get_datasets_user_agent()
def fetch_single_image(image_url, timeout=None, retries=0):
for _ in range(retries + 1):
try:
request = urllib.request.Request(
image_url,
data=None,
headers={"user-agent": USER_AGENT},
)
with urllib.request.urlopen(request, timeout=timeout) as req:
image = PIL.Image.open(io.BytesIO(req.read()))
break
except Exception:
image = None
return image
def fetch_images(batch, num_threads, timeout=None, retries=0):
fetch_single_image_with_args = partial(fetch_single_image, timeout=timeout, retries=retries)
with ThreadPoolExecutor(max_workers=num_threads) as executor:
batch["image"] = list(executor.map(fetch_single_image_with_args, batch["image_url"]))
return batch
num_threads = 20
dset = load_dataset("conceptual_captions", split='train', streaming=True)
# dset = dset.map(fetch_images, batched=True, batch_size=32)
dset = dset.map(fetch_images, batched=True, batch_size=32, fn_kwargs={"num_threads": num_threads})
print(next(iter(dset)))
Here is the error