Hi all, I’m using a streaming dataset and passing it into a regular PyTorch DataLoader. I have a simple map function that pads the various sequence fields in each example.
def pad_map(x):
    # Pad each sequence field to a fixed length (pad is my own helper).
    x['words1'] = pad(x['words1'])
    x['words2'] = pad(x['words2'])
    x['tags1'] = pad(x['tags1'])
    x['tags2'] = pad(x['tags2'])
    return x
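For context, pad is just my own helper that truncates and right-pads each list to a fixed length; roughly something like this (max_len and pad_id here are placeholders, not the real values):

def pad(seq, max_len=128, pad_id=0):
    # Placeholder sketch of my padding helper: truncate to max_len,
    # then right-pad with pad_id so every sequence has equal length.
    seq = seq[:max_len]
    return seq + [pad_id] * (max_len - len(seq))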
When I pass the dataset into the PyTorch DataLoader (for example with batch_size=4), I get this unexpected result:
batch = {'words1': [tensor([  44, 3399,  419, 3256]),
                    tensor([ 312,  222,  667,  290]),
                    tensor([ 120,  883,   99,  101]),
                    ...
                    tensor([0, 0, 0, 0]),
                    tensor([0, 0, 0, 0]),
                    tensor([0, 0, 0, 0]),
                    tensor([0, 0, 0, 0])],
         'label': tensor([1, 0, 0, 0])}
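I think this happens because the mapped examples still hold plain Python lists, so the default collate zips them together position by position instead of stacking one row per example. A minimal repro of that behaviour:

from torch.utils.data import default_collate

# Two examples whose values are plain lists, as they are after my map:
examples = [{'words1': [1, 2, 3]}, {'words1': [4, 5, 6]}]
print(default_collate(examples))
# {'words1': [tensor([1, 4]), tensor([2, 5]), tensor([3, 6])]}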
The only fix I see is to post-process each batch with torch.stack(batch['words1']).T (and similarly for the other sequence keys) to get the proper shape.
Is there any way for the map function to produce an output like the following?
{'words1': tensor of shape (batch_size, seq_len), ..., 'label': tensor of shape (batch_size,)}
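I know I can get that shape by moving the stacking into a custom collate_fn, something like the sketch below (collate is just my own name for it), but I’d prefer the map/dataset itself to emit tensors directly:

import torch
from torch.utils.data import DataLoader

def collate(examples):
    # Stack each padded sequence field into (batch_size, seq_len)
    # and the labels into (batch_size,).
    batch = {key: torch.tensor([ex[key] for ex in examples])
             for key in ('words1', 'words2', 'tags1', 'tags2')}
    batch['label'] = torch.tensor([ex['label'] for ex in examples])
    return batch

# dataset is the streamed dataset from above, after pad_map
loader = DataLoader(dataset, batch_size=4, collate_fn=collate)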
Thanks for your help!