Hi,
this inconsistency is due to how PyArrow converts nested sequences to NumPy by default. It can be fixed by casting the corresponding column to the ArrayXD type.
E.g. in your example:
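from datasets import Array2D, Dataset, Features, Value

# Declare "a" as a fixed-shape 2x2 array instead of a nested sequence
dset = Dataset.from_dict(
    {"a": [[[2,1],[2,2]], [[3,1],[3,2]]], "b": [1,1]},
    features=Features({"a": Array2D(shape=(2, 2), dtype="int32"), "b": Value("int32")})
)
dset.set_format('torch', columns=['a','b'])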
If you want to cast the existing dataset, use map instead of cast (cast fails on special extension types):
features = Features({"a": Array2D(shape=(2, 2), dtype="int32"), "b": Value("int32")})
dset = Dataset.from_dict({"a": [[[2,1],[2,2]], [[3,1],[3,2]]], "b": [1,1]})
# An identity map re-encodes the columns with the new features
dset = dset.map(lambda batch: batch, batched=True, features=features)
dset.set_format('torch', columns=['a','b'])
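
Array2D needs a single fixed shape for every example in the column, so it can help to check the nested lists before picking the shape argument. Below is a minimal pure-Python sketch of such a check; fixed_shape and column_shape are hypothetical helper names of my own, not part of the datasets library.

```python
def fixed_shape(example):
    """Return the shape of a rectangular nested list; raise ValueError if it is ragged."""
    if not isinstance(example, (list, tuple)):
        return ()  # a scalar has an empty shape
    if len(example) == 0:
        return (0,)
    inner = [fixed_shape(item) for item in example]
    if any(shape != inner[0] for shape in inner):
        raise ValueError(f"ragged nesting: {inner}")
    return (len(example),) + inner[0]


def column_shape(column):
    """Return the per-example shape shared by all examples; raise ValueError otherwise."""
    shapes = {fixed_shape(example) for example in column}
    if len(shapes) != 1:
        raise ValueError(f"examples do not share one shape: {sorted(shapes)}")
    return shapes.pop()


# Hypothetical usage with the "a" column from the example above
column_a = [[[2, 1], [2, 2]], [[3, 1], [3, 2]]]
shape = column_shape(column_a)
print(shape)
```

If the check passes, the resulting tuple can be passed as the shape argument of Array2D; if it raises, the column does not have one fixed shape and cannot be declared that way.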