I'm having an issue with the set_format function in 🤗 Datasets: I expect to get a single tensor, but I get a list of tensors instead. Columns holding scalars or 1-D lists come back as normal torch tensors, but when a column holds 2-D lists per row (so a batch is 3-D), it returns a nested list of tensors.
Here is an example:
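```python
from datasets import Dataset

# ex1: each row of 'a' is a 1-D list, so a batch of 'a' is 2-D
ex1 = {'a': [[1, 1], [1, 2]], 'b': [1, 1]}
# ex2: each row of 'a' is a 2-D list, so a batch of 'a' is 3-D
ex2 = {'a': [[[2, 1], [2, 2]], [[3, 1], [3, 2]]], 'b': [1, 1]}

d1 = Dataset.from_dict(ex1)
d1.set_format('torch', columns=['a', 'b'])
d2 = Dataset.from_dict(ex2)
d2.set_format('torch', columns=['a', 'b'])

print(d1[:2])
print(d2[:2])
```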
and the output is:
```
{'a': tensor([[1, 1],
        [1, 2]]), 'b': tensor([1, 1])}
{'a': [[tensor([2, 1]), tensor([2, 2])], [tensor([3, 1]), tensor([3, 2])]], 'b': tensor([1, 1])}
```
I expected d2[:2]['a'] to come back as a single 3-D tensor of shape (2, 2, 2). Why is it returning a list?
Any clarification would be appreciated. Thank you.
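Until the cause is clear, one possible workaround is to collapse the nested list into a single tensor after indexing. This is a minimal sketch, not a Datasets feature: the helper name stack_nested is hypothetical, and it assumes every list at a given depth contains tensors of identical shape (as in the d2 output above), since torch.stack raises an error otherwise.

```python
import torch


def stack_nested(x):
    """Hypothetical helper: recursively stack nested lists of tensors.

    Assumes all items at each nesting level share the same shape;
    torch.stack raises a RuntimeError if they do not.
    """
    if isinstance(x, torch.Tensor):
        return x  # already a tensor, nothing to stack
    # Stack the (recursively stacked) children along a new leading dimension
    return torch.stack([stack_nested(item) for item in x])


# Rebuild the nested structure printed for d2[:2]['a']
nested = [[torch.tensor([2, 1]), torch.tensor([2, 2])],
          [torch.tensor([3, 1]), torch.tensor([3, 2])]]

a = stack_nested(nested)
print(a.shape)  # torch.Size([2, 2, 2])
```

Applied to the real batch, this would be batch = d2[:2] followed by batch['a'] = stack_nested(batch['a']).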