Creating a dataset of NumPy arrays: should I use Sequence or ArrayND?

I went with the following approach. It would be great if someone could check it.

import numpy as np
from datasets import load_from_disk

class NumpyTransform:
    """Convert selected columns of a batch to NumPy arrays.

    Intended for ``Dataset.with_transform``: the instance is called with a
    batch (a mapping of column name -> values) and returns a new dict in
    which the selected columns are ``np.ndarray`` and every other column is
    passed through unchanged.

    Args:
        features: Mapping of column name -> feature object (e.g.
            ``dataset.features``). Columns handled via ``arr_types`` must
            expose a ``dtype`` attribute; columns handled via ``seq_types``
            must expose a ``feature`` attribute (possibly nested) whose
            innermost item has a ``dtype``.
        arr_types: Columns converted with the feature's own ``dtype``.
            ``None`` (default) selects every feature column that is not
            listed in ``seq_types``. An explicit empty list selects none.
        seq_types: Sequence-like columns converted with the dtype of their
            innermost element feature. ``None`` means no such columns.
    """

    def __init__(self, features, arr_types=None, seq_types=None):
        self._feats = features
        # Frozensets give O(1) membership tests for every key of every batch.
        self._seq_types = frozenset(seq_types or ())
        if arr_types is None:
            # Leave sequence columns out of the default: a Sequence feature's
            # own ``dtype`` is not a NumPy dtype, so these columns must reach
            # the seq branch in __call__ instead of being shadowed here.
            arr_types = [key for key in self._feats if key not in self._seq_types]
        self._arr_types = frozenset(arr_types)

    @staticmethod
    def _leaf_dtype(feature):
        """Return the dtype of the innermost element of a (nested) sequence feature."""
        # Nested sequences chain through ``.feature``; unwrap down to the leaf.
        while hasattr(feature, "feature"):
            feature = feature.feature
        return feature.dtype

    def __call__(self, batch):
        """Return a copy of *batch* with the configured columns as ndarrays.

        Columns in ``arr_types`` take precedence over ``seq_types`` if a key
        is listed in both. Columns in neither are passed through as-is.

        NOTE(review): ``np.asarray`` needs rectangular data. Ragged sequence
        values (rows of different lengths) will raise; confirm the data is
        fixed-length before relying on this.
        """
        sample = {}
        for key, val in batch.items():
            if key in self._arr_types:
                val = np.asarray(val, dtype=self._feats[key].dtype)
            elif key in self._seq_types:
                val = np.asarray(val, dtype=self._leaf_dtype(self._feats[key]))
            sample[key] = val
        return sample

# Load the saved dataset, then attach the NumPy conversion as an on-access
# transform (with_transform applies it when rows/batches are read).
dataset = load_from_disk(data_dir, keep_in_memory=None)
to_numpy = NumpyTransform(
    dataset.features,
    arr_types=["example", "label", "coords_label"],
    seq_types=["coords_num"],
)
dataset = dataset.with_transform(to_numpy)
1 Like