# Getting list of tensors instead of tensor array after using set_format

So I'm a bit unclear about how `set_format` converts data into torch tensors. My understanding is that if I apply `set_format` to a dataset whose values are lists, it would convert them to tensors. I do get normal torch tensors for 1d and 2d lists, but when I have 3d lists, it somehow returns a list of tensors.
Here is a toy example:

```python
from datasets import Dataset

ex1 = {'a':[[1,1],[1,2]], 'b':[1,1]}
ex2 = {'a':[[[2,1],[2,2]], [[3,1],[3,2]]], 'b':[1,1]}

d1 = Dataset.from_dict(ex1)
d1.set_format('torch', columns=['a','b'])
d2 = Dataset.from_dict(ex2)
d2.set_format('torch', columns=['a','b'])

print(d1[:2])
print(d2[:2])
```

and the output is:

```
{'a': tensor([[1, 1],
        [1, 2]]), 'b': tensor([1, 1])}
{'a': [[tensor([2, 1]), tensor([2, 2])], [tensor([3, 1]), tensor([3, 2])]], 'b': tensor([1, 1])}
```

As you can see, for the d2 dataset with 3d lists I'm getting a list of tensors, instead of a single tensor as was the case for the 1d and 2d lists in d1. Why does it return a list of tensors for 3d lists? Is there a way to get a single tensor for such lists?
I'd really appreciate a clarification on this. Thank you.

Hi,

this inconsistency comes from how PyArrow converts nested sequences to NumPy by default; it can be fixed by casting the corresponding column to the `ArrayXD` type.

```python
from datasets import Dataset, Features, Array2D, Value

dset = Dataset.from_dict(
    {"a": [[[2, 1], [2, 2]], [[3, 1], [3, 2]]], "b": [1, 1]},
    features=Features({"a": Array2D(shape=(2, 2), dtype="int32"), "b": Value("int32")}),
)
dset.set_format('torch', columns=['a', 'b'])
```

If you want to cast an existing dataset, use `map` instead of `cast` (`cast` fails on special extension types):

```python
from datasets import Dataset, Features, Array2D, Value

dset = Dataset.from_dict({"a": [[[2, 1], [2, 2]], [[3, 1], [3, 2]]], "b": [1, 1]})
dset = dset.map(
    lambda batch: batch,
    batched=True,
    features=Features({"a": Array2D(shape=(2, 2), dtype="int32"), "b": Value("int32")}),
)
dset.set_format('torch', columns=['a', 'b'])
```
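For comparison, here is what the `Array2D` cast is aiming for: the nested list from the question is uniformly shaped, so plain `torch.tensor` can convert it directly into a single 3d tensor (one 2×2 matrix per example, giving shape `(2, 2, 2)` for two examples) instead of a list of row tensors. A minimal sketch using only torch:

```python
import torch

# The nested list from the question: two examples, each a 2x2 matrix.
nested = [[[2, 1], [2, 2]], [[3, 1], [3, 2]]]

# A uniformly nested list converts directly to a single 3d tensor.
batch = torch.tensor(nested)
print(batch.shape)  # torch.Size([2, 2, 2])
print(batch[0])     # the first example as a 2x2 tensor
```

This is the shape you should see from `dset[:2]['a']` once the column is typed as `Array2D` and the format is set to `'torch'`.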