Odd dataset.map() behavior with PyTorch dataloader

Hi all, I’m using a streaming dataset and passing it into a normal PyTorch DataLoader. I have a simple map function that pads the various sequences in the dataset:

def pad_map(x):
    x['words1'] = pad(x['words1'])
    x['words2'] = pad(x['words2'])
    x['tags1'] = pad(x['tags1'])
    x['tags2'] = pad(x['tags2'])
    return x
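
For reference, here’s roughly how the dataset side is wired up. A minimal sketch, where the dataset name, MAX_LEN, and the pad helper are placeholders for my real ones:

from datasets import load_dataset

MAX_LEN = 16  # placeholder; my real pad() targets a fixed length

def pad(seq, max_len=MAX_LEN):
    # placeholder padding helper: truncate, then right-pad with zeros
    seq = list(seq)[:max_len]
    return seq + [0] * (max_len - len(seq))

dataset = load_dataset("my/dataset", split="train", streaming=True)  # placeholder name
dataset = dataset.map(pad_map)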

When I pass the dataset into the PyTorch DataLoader (for example with a batch_size of 4), I get this unexpected result:

batch = {'words1': [tensor([  44, 3399,  419, 3256]),
  tensor([312, 222, 667, 290]),
  tensor([120, 883,  99, 101]),
...
  tensor([0, 0, 0, 0]),
  tensor([0, 0, 0, 0]),
  tensor([0, 0, 0, 0]),
  tensor([0, 0, 0, 0])],
 'label': tensor([1, 0, 0, 0])}

The only way I see to fix this is to call torch.stack(batch['words1']).T (and the same for the other keys) to get the proper format.
Is there any way for the map function to produce an output like the following?
{'words1': tensor of shape (batch_size, seq_len), ..., 'labels': tensor of shape (batch_size,)}
Thanks for your help!

I figured it out: creating a tensor in the map function stops the DataLoader from batching strangely. (PyTorch’s default collate_fn zips Python lists and collates them position by position, which is why each sequence position came back as its own (batch_size,) tensor; tensors, on the other hand, just get stacked along a new batch dimension.) So something like:

import torch

def pad_map(x):
    # Returning tensors (not Python lists) lets the default collate
    # stack them into a single (batch_size, seq_len) tensor per key.
    x['words1'] = torch.tensor(pad(x['words1']))
    x['words2'] = torch.tensor(pad(x['words2']))
    x['tags1'] = torch.tensor(pad(x['tags1']))
    x['tags2'] = torch.tensor(pad(x['tags2']))
    return x
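
With pad_map returning tensors, the default collate stacks the per-example tensors along a new batch dimension, so (since every field is padded to the same length) the batch now comes out in the shape I was after:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=4)
batch = next(iter(loader))
print(batch['words1'].shape)  # torch.Size([4, MAX_LEN])
print(batch['label'].shape)   # torch.Size([4])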
