I went with the following approach. It would be great if someone could check it.
import numpy as np
from datasets import load_from_disk
class NumpyTransform:
    """Convert selected columns of a batch to NumPy arrays.

    Instances are callables meant to be passed to ``Dataset.with_transform``.
    Each call receives a mapping of column name -> values and returns a new
    dict with the same keys. Configured columns are converted with
    ``np.asarray``; all other columns are passed through unchanged.

    Args:
        features: Mapping of column name -> feature object (e.g.
            ``dataset.features``). Columns handled via ``arr_types`` must
            expose ``.dtype``. Columns handled via ``seq_types`` must expose
            ``.feature.dtype`` (the element dtype of the inner feature).
        arr_types: Columns converted with ``features[key].dtype``. If falsy,
            defaults to every key of ``features`` that is not listed in
            ``seq_types``.
        seq_types: Columns converted with ``features[key].feature.dtype``.
            Defaults to no columns.
    """

    def __init__(self, features, arr_types=None, seq_types=None):
        self._feats = features
        # Resolve seq_types first so the default arr_types can exclude them.
        # arr_types is checked before seq_types in __call__, so a sequence
        # column in the default list would be converted with the Sequence
        # feature's own ``.dtype`` (in ``datasets`` that is the string
        # "list", not a NumPy dtype) and the seq branch would never run.
        self._seq_types = seq_types or []
        self._arr_types = arr_types or [
            key for key in self._feats.keys() if key not in self._seq_types
        ]

    def __call__(self, batch):
        """Return a copy of *batch* with the configured columns as arrays.

        If a key appears in both ``arr_types`` and ``seq_types``,
        ``arr_types`` takes precedence.
        """
        sample = {}
        for key, val in batch.items():
            if key in self._arr_types:
                val = np.asarray(val, dtype=self._feats[key].dtype)
            elif key in self._seq_types:
                # The element dtype of a sequence lives on its inner feature.
                # NOTE(review): np.asarray needs equal-length sequences within
                # a batch; ragged lists will raise. Confirm the data is
                # rectangular.
                val = np.asarray(val, dtype=self._feats[key].feature.dtype)
            sample[key] = val
        return sample
# Load the saved dataset, then attach a lazy transform so that the listed
# columns come back as NumPy arrays whenever rows are accessed.
# NOTE(review): `data_dir` must be defined earlier; it is not visible here.
# Presumably load_from_disk returns a single Dataset (a DatasetDict has no
# `.features`). Confirm.
dataset = load_from_disk(data_dir, keep_in_memory=None)
dataset = dataset.with_transform(
    transform=NumpyTransform(
        dataset.features,
        arr_types=["example", "label", "coords_label"],
        seq_types=["coords_num"],
    ),
)