set_format in datasets returns a list of tensors instead of a single tensor

So I’m having an issue with the set_format function in datasets: I expect to get a single tensor, but instead I get a list of tensors. For 1D and 2D lists I get normal torch tensors, but when I pass 3D lists, it somehow returns a nested list of tensors.
Here is an example:

from datasets import Dataset

ex1 = {'a':[[1,1],[1,2]], 'b':[1,1]}
ex2 = {'a':[[[2,1],[2,2]], [[3,1],[3,2]]], 'b':[1,1]}

d1 = Dataset.from_dict(ex1)
d1.set_format('torch', columns=['a','b'])
d2 = Dataset.from_dict(ex2)
d2.set_format('torch', columns=['a','b'])

print(d1[:2])
print(d2[:2])

and the output is:

{'a': tensor([[1, 1],
        [1, 2]]), 'b': tensor([1, 1])}
{'a': [[tensor([2, 1]), tensor([2, 2])], [tensor([3, 1]), tensor([3, 2])]], 'b': tensor([1, 1])}

I was expecting to get a single 3D tensor for d2. Why is it returning a list?
I would appreciate any clarification. Thank you.


I encountered exactly the same problem: if the array has 3 dimensions, set_format does not stack it, and I get a list of 2D tensors instead of one 3D tensor. To solve this, simply cast the column to the Array3D type, e.g. with the Dataset.cast_column function.