Using set_transform on a dataset leads to an exception

I’m trying to do the exact same thing mentioned in the documentation:

from datasets import Dataset

audio_dataset_amr = Dataset.from_dict({"audio": ["audio_samples/audio.amr"]})

def decode_audio_with_pydub(batch):
    return batch

audio_dataset_amr.set_transform(decode_audio_with_pydub)
audio_dataset_amr.save_to_disk(f"./transformed_dataset")

But it fails with the following error:

Exception has occurred: TypeError
Object of type function is not JSON serializable
The format kwargs must be JSON serializable, but key 'transform' isn't.
TypeError: Object of type function is not JSON serializable

During handling of the above exception, another exception occurred:

  File "/home/mehran/tmp/locale_classifier/transform_dataset_test.py", line 59, in <module>
    audio_dataset_amr.save_to_disk(f"./transformed_dataset")
TypeError: Object of type function is not JSON serializable
The format kwargs must be JSON serializable, but key 'transform' isn't.

Is this a bug or am I doing something wrong?

BTW, if I comment out the set_transform call, everything works just fine.

For this to work, save_to_disk would have to pickle the custom transform, which means the transform and every object it references would need to be serializable. Deserializing those bytes would also make load_from_disk unsafe, since unpickling can execute arbitrary code, so I’m not sure this is a good idea.
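To make the two concerns above concrete, here is a minimal standard-library sketch. It does not use the datasets library and is not how save_to_disk is implemented; the names RunsOnLoad, json_error, lambda_pickles and loaded are made up for illustration. It shows that a function is not a JSON value, that not every callable can be pickled, and that loading a pickle can run code chosen by whoever wrote the file.

```python
import json
import pickle


def decode_audio_with_pydub(batch):
    # Same no-op transform as in the question.
    return batch


# 1) A function is not a JSON value, which matches the reported TypeError.
try:
    json.dumps({"transform": decode_audio_with_pydub})
    json_error = None
except TypeError as exc:
    json_error = str(exc)

# 2) Pickle is not a drop-in fix: a lambda cannot be pickled by the
#    standard pickle module, because it has no importable name.
try:
    pickle.dumps(lambda batch: batch)
    lambda_pickles = True
except (pickle.PicklingError, AttributeError):
    lambda_pickles = False


# 3) Unpickling is unsafe: __reduce__ lets a pickled object name any
#    callable to be invoked at load time (hypothetical, harmless example).
class RunsOnLoad:
    def __reduce__(self):
        return (eval, ("6 * 7",))


loaded = pickle.loads(pickle.dumps(RunsOnLoad()))

print(json_error)
print(lambda_pickles)
print(loaded)
```

In practice, a likely workaround is to call save_to_disk before set_transform and re-apply the transform after load_from_disk, or to materialize the transformed data with map before saving. Treat this as a suggestion to verify against your datasets version rather than an official recommendation.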

EDIT:

I’ve opened a GH issue here to further discuss this matter.