How to use Dataset with PyTorch Lightning

I think you also need to specify which columns you'd like to keep when calling .set_format(type='torch'). If you don't, the text columns are still returned with each example, and converting those strings to PyTorch tensors causes an error.

So I think you just need to update that line to:

train_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
