`datasets.map` calls a function that requires a `transformers.PreTrainedModel` object: unpicklable object

I am trying to cache an embedding computation that requires a `transformers.PreTrainedModel`. I am getting the following warning:

```
Parameter 'fn_kwargs'=
of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
```

Here is a minimal reproducible example:

```python
from transformers import CLIPModel, CLIPProcessor
import torch
import numpy as np
from datasets import load_dataset

vision_encoder = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
vision_encoder_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')

# The function name is a placeholder; its keyword parameters match the
# keys passed through `fn_kwargs` below.
def compute_vision_encoder_embeddings(examples, DEVICE, VISION_ENCODER_PROCESSOR, VISION_ENCODER, TASK_IMAGE_COLUMN_NAME):
    "Essentially, project the images with the vision encoder."
    with torch.no_grad():
        # CLIP expects RGB images; convert any other mode (e.g. grayscale).
        images = [img.convert("RGB") if img.mode != "RGB" else img for img in examples[TASK_IMAGE_COLUMN_NAME]]
        pixel_values = VISION_ENCODER_PROCESSOR(images=images, return_tensors="pt")["pixel_values"].to(DEVICE)
        image_embeddings = VISION_ENCODER.get_image_features(pixel_values=pixel_values)
    image_embeddings = image_embeddings.cpu().numpy()
    # L2-normalize each embedding row.
    image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, ord=2, axis=1)[:, None]
    examples["vision_encoder_embeddings"] = image_embeddings
    return examples

dataset = load_dataset("cifar10", split="train")
dataset = dataset.map(
    compute_vision_encoder_embeddings,
    batched=True,  # the function processes a batch (a list of images) at a time
    fn_kwargs={
        "DEVICE": "cpu",
        "VISION_ENCODER_PROCESSOR": vision_encoder_processor,
        "VISION_ENCODER": vision_encoder,
        "TASK_IMAGE_COLUMN_NAME": "img",
    },
)
```

Digging into why the hashing (and therefore the caching) fails, it looks like a `PreTrainedModel` cannot be pickled:

```python
from datasets.utils.py_utils import dumps

dumps(vision_encoder)
```

raises a `PicklingError` (see `datasets/fingerprint.py` at commit 0a067b4b433e1c36a5cabac34cbb630f89dc7eeb in huggingface/datasets on GitHub).
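Because pickling stops at the first object it cannot serialize, it can help to test each `fn_kwargs` value on its own to see which entry breaks the fingerprint. Below is a minimal stdlib sketch; `find_unpicklable` is a hypothetical helper, not part of `datasets`. It uses plain `pickle`, which is stricter than the dill-based `dumps` that `datasets` uses for fingerprinting, so swapping in `datasets.utils.py_utils.dumps` gives a closer match.

```python
import pickle
from typing import Any, Dict, List


def find_unpicklable(kwargs: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: return the keys whose values fail to pickle.

    Uses plain `pickle`; `datasets` hashes with a dill-based serializer,
    so results can differ (e.g. dill can handle lambdas, pickle cannot).
    """
    failing = []
    for key, value in kwargs.items():
        try:
            pickle.dumps(value)
        except Exception:  # PicklingError, TypeError, AttributeError, ...
            failing.append(key)
    return failing


if __name__ == "__main__":
    import threading

    # A lock stands in for an object that cannot be pickled.
    print(find_unpicklable({"DEVICE": "cpu", "LOCK": threading.Lock()}))
```

Calling it on the real `fn_kwargs` dict would point at the model and processor entries if those are the objects that fail to serialize.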

Any idea how I can make the hashing work properly?

Oh, and in case it’s useful, here’s my environment:

- `datasets` version: 2.7.0
- Platform: Linux-5.4.0-1093-gcp-x86_64-with-glibc2.17
- Python version: 3.8.13
- PyArrow version: 7.0.0
- Pandas version: 1.5.0

Sounds like an issue in `transformers`: models should be picklable IMO. Could you open an issue there, please?

In the meantime, you can choose the dataset’s new fingerprint (which is used for caching) yourself by passing `new_fingerprint=` to `.map()`. Be careful, though: if you change your map function or its parameters, make sure to change the new fingerprint as well, or it may reload previously computed results!
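One way to follow that advice is to derive the fingerprint deterministically from everything that should invalidate the cache: the input dataset’s fingerprint, a version string you bump whenever the function body changes, and JSON-serializable parameters (e.g. the checkpoint name instead of the model object). This is a minimal sketch under those assumptions; `make_map_fingerprint` and its arguments are hypothetical, not a `datasets` API.

```python
import hashlib
import json
from typing import Any, Dict


def make_map_fingerprint(parent_fingerprint: str, fn_name: str,
                         fn_version: str, params: Dict[str, Any]) -> str:
    """Hypothetical helper: build a deterministic fingerprint for `.map()`.

    `params` must be JSON-serializable, so pass identifiers such as the
    checkpoint name rather than the model object itself.
    """
    payload = json.dumps(
        {
            "parent": parent_fingerprint,  # changes when the input data changes
            "fn": fn_name,
            "version": fn_version,         # bump this when the function body changes
            "params": params,
        },
        sort_keys=True,  # key order must not affect the result
    )
    # Truncated hex digest: short and safe to use in cache file names.
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]


if __name__ == "__main__":
    print(make_map_fingerprint(
        "parent-fp", "compute_embeddings", "v1",
        {"checkpoint": "openai/clip-vit-base-patch32", "image_column": "img", "device": "cpu"},
    ))
```

The result would then be passed as `dataset.map(your_fn, ..., new_fingerprint=fingerprint)`. For the parent fingerprint, one option is the input dataset’s `_fingerprint` attribute; it is private, so treat relying on it as an assumption that may change between `datasets` versions.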