So I’m a bit unclear about how set_format converts data into torch tensors. My understanding is that if I apply set_format to a dataset that has lists as values, it will convert them to tensors. I do get normal torch tensors for 1d and 2d lists, but when I have 3d lists, it somehow returns nested lists of tensors instead.
Here is a toy example:
from datasets import Dataset
ex1 = {'a':[[1,1],[1,2]], 'b':[1,1]}
ex2 = {'a':[[[2,1],[2,2]], [[3,1],[3,2]]], 'b':[1,1]}
d1 = Dataset.from_dict(ex1)
d1.set_format('torch', columns=['a','b'])
d2 = Dataset.from_dict(ex2)
d2.set_format('torch', columns=['a','b'])
print(d1[:2])
print(d2[:2])
and the output is:
{'a': tensor([[1, 1],
        [1, 2]]), 'b': tensor([1, 1])}
{'a': [[tensor([2, 1]), tensor([2, 2])], [tensor([3, 1]), tensor([3, 2])]], 'b': tensor([1, 1])}
As you can see, for the d2 dataset with 3d lists, I’m getting nested lists of tensors instead of a single tensor, as was the case for the 1d and 2d lists in the d1 dataset. Why does it return lists of tensors for 3d lists? Is there a way to get a single tensor for such lists?
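For now, the only workaround I can think of is to stack the nested lists myself with torch.stack. Here is a minimal sketch; stack_nested is my own hypothetical helper (not part of datasets), and it assumes all the inner tensors have the same shape:

```python
import torch

def stack_nested(nested):
    """Recursively stack a nested list of tensors into one tensor.

    Hypothetical helper, not part of the datasets library.
    Assumes every tensor at the same nesting level has the same shape.
    """
    if isinstance(nested, torch.Tensor):
        return nested
    return torch.stack([stack_nested(item) for item in nested])

# Same structure as the 'a' column returned by d2[:2] above
a = [[torch.tensor([2, 1]), torch.tensor([2, 2])],
     [torch.tensor([3, 1]), torch.tensor([3, 2])]]

stacked = stack_nested(a)
print(stacked.shape)  # torch.Size([2, 2, 2])
```

This works on the toy example, but it feels like something set_format should handle for me, so I’d still like to know whether there is a built-in way.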
I’d really appreciate a clear explanation of this. Thank you.