Hi
I have a Dataset that contains both torch Tensor columns and regular Python data columns (text, float).
As soon as I set the format of one of the columns to type='torch', iterating over the dataset only yields that specific column instead of all of the columns.
Example:
import torch
from datasets import Dataset
ds_dict = {'text':['text1','text2'],'tens':torch.Tensor([1,2])}
ds = Dataset.from_dict(ds_dict)
for row in ds:
    print(row)
Output:
{'text': 'text1', 'tens': 1.0}
{'text': 'text2', 'tens': 2.0}
Now setting the format of 'tens' to torch:
ds.set_format(type='torch',columns=['tens'])
for row in ds:
    print(row)
Output:
{'tens': tensor(1.)}
{'tens': tensor(2.)}
The iteration is now missing the 'text' column.
But the column is still in the dataset:
print(ds)
Output:
Dataset({
    features: ['text', 'tens'],
    num_rows: 2
})