Hey!

I am trying to cache an embeddings computation that requires a `transformers.PreTrainedModel`. When I call `Dataset.map`, I get the following warning:
```
Parameter 'fn_kwargs'=
[...]
of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
```
A dummy reproducible example is the following:
```python
from transformers import CLIPModel, CLIPProcessor
import torch
import numpy as np
from datasets import load_dataset

vision_encoder = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
vision_encoder.eval()
vision_encoder_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')


def vision_encoder_embeddings(examples, DEVICE, VISION_ENCODER_PROCESSOR, VISION_ENCODER, TASK_IMAGE_COLUMN_NAME):
    "Essentially, project the images with the vision encoder."
    with torch.no_grad():
        images = [img.convert("RGB") if img.mode != "RGB" else img for img in examples[TASK_IMAGE_COLUMN_NAME]]
        pixel_values = VISION_ENCODER_PROCESSOR(images=images, return_tensors="pt")["pixel_values"].to(
            DEVICE
        )
        image_embeddings = VISION_ENCODER.get_image_features(pixel_values=pixel_values)
        image_embeddings = image_embeddings.cpu().numpy()
        image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, ord=2, axis=1)[:, None]
    examples["vision_encoder_embeddings"] = image_embeddings
    return examples


dataset = load_dataset("cifar10", split="train")
dataset.map(
    vision_encoder_embeddings,
    batched=True,
    batch_size=4,
    fn_kwargs={
        "DEVICE": "cpu",
        "VISION_ENCODER_PROCESSOR": vision_encoder_processor,
        "VISION_ENCODER": vision_encoder,
        "TASK_IMAGE_COLUMN_NAME": "img",
    },
)
```
Diving into why the hashing (and thus the expected caching) fails, it looks like the PreTrainedModel fails to be pickled:
```python
from datasets.utils.py_utils import dumps

dumps(vision_encoder)
```
raises a PicklingError
(see datasets/fingerprint.py at 0a067b4b433e1c36a5cabac34cbb630f89dc7eeb · huggingface/datasets · GitHub).
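To narrow it down, here is a small sketch that checks each `fn_kwargs` entry individually. It uses `datasets.fingerprint.Hasher`, which I understand is the helper the fingerprinting/caching mechanism relies on internally; treat the exact API as an assumption on my side:

```python
# Hedged sketch: try to hash each fn_kwargs entry separately to see which
# one(s) break the fingerprint computation. Assumes datasets.fingerprint.Hasher
# exposes a `hash` classmethod (as used internally by the caching mechanism).
from datasets.fingerprint import Hasher

fn_kwargs = {
    "DEVICE": "cpu",
    "VISION_ENCODER_PROCESSOR": vision_encoder_processor,
    "VISION_ENCODER": vision_encoder,
    "TASK_IMAGE_COLUMN_NAME": "img",
}

for name, value in fn_kwargs.items():
    try:
        print(name, Hasher.hash(value))
    except Exception as e:  # e.g. PicklingError for the model
        print(name, "-> hashing failed:", type(e).__name__)
```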
Any idea how I can make the hashing work properly?