How to use Dataset with PyTorch Lightning


I’m trying to load the cnn_dailymail dataset to train a summarization model with PyTorch Lightning. To load the dataset with DataLoader I tried to follow the documentation, but it doesn’t work. (The PyTorch Lightning code I’m using does work when the DataLoader isn’t using a dataset from Hugging Face, so there shouldn’t be a problem in the training procedure.)

Here is the code:

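from datasets import load_dataset
from torch.utils.data import DataLoader

def train_dataloader(self):
    train_dataset = load_dataset('cnn_dailymail', '3.0.0', split='train')
    # `tokenizer` is defined elsewhere in the module
    train_dataset = train_dataset.map(lambda e: tokenizer(e['article'], e['highlights'], truncation=True, padding='max_length'), batched=True)
    train_dataset.set_format(type='torch')
    dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size)
    return dataloader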

I passed ‘article’ and ‘highlights’ to the tokenizer because these are the two columns in the dataset that correspond to the source and the target.
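Side note on the tokenizer call: passing two lists of strings to a Hugging Face tokenizer encodes them as a text pair, i.e. article and highlights end up concatenated into one input sequence. For summarization, one common alternative (not what the post above does) is to encode the source and target separately and store the target ids as labels. A minimal sketch, assuming `tokenizer` behaves like a Hugging Face tokenizer (called on a list of strings, it returns a mapping with 'input_ids' and 'attention_mask'); the function name `preprocess` and the length limits are illustrative choices, not library defaults:

```python
def preprocess(batch, tokenizer, max_source_length=512, max_target_length=128):
    """Tokenize a batch of cnn_dailymail rows for seq2seq summarization.

    `batch` is a dict of lists, as passed by Dataset.map(..., batched=True).
    `tokenizer` is assumed to behave like a Hugging Face tokenizer.
    """
    # Encode the articles (source) on their own, not as a text pair.
    model_inputs = dict(tokenizer(batch['article'], truncation=True,
                                  padding='max_length', max_length=max_source_length))
    # Encode the highlights (target) separately; keep only their ids as labels.
    targets = tokenizer(batch['highlights'], truncation=True,
                        padding='max_length', max_length=max_target_length)
    model_inputs['labels'] = targets['input_ids']
    return model_inputs
```

With a real tokenizer this would be applied as `train_dataset.map(lambda b: preprocess(b, tokenizer), batched=True)`. Many seq2seq setups additionally replace the padding ids in `labels` with -100, which PyTorch’s cross-entropy loss ignores by default.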

Here is the stack trace:

Traceback (most recent call last):
  File "", line 346, in <module>
  File "venv/lib/python3.6/site-packages/pytorch_lightning/trainer/", line 48, in wrapped_fn
    result = fn(self, *args, **kwargs)
  File "venv/lib/python3.6/site-packages/pytorch_lightning/trainer/", line 1073, in fit
    results = self.accelerator_backend.train(model)
  File "venv/lib/python3.6/site-packages/pytorch_lightning/accelerators/", line 51, in train
    results = self.trainer.run_pretrain_routine(model)
  File "venv/lib/python3.6/site-packages/pytorch_lightning/trainer/", line 1224, in run_pretrain_routine
    self._run_sanity_check(ref_model, model)
  File "venv/lib/python3.6/site-packages/pytorch_lightning/trainer/", line 1257, in _run_sanity_check
    eval_results = self._evaluate(model, self.val_dataloaders, max_batches, False)
  File "venv/lib/python3.6/site-packages/pytorch_lightning/trainer/", line 305, in _evaluate
    for batch_idx, batch in enumerate(dataloader):
  File "venv/lib/python3.6/site-packages/torch/utils/data/", line 517, in __next__
    data = self._next_data()
  File "venv/lib/python3.6/site-packages/torch/utils/data/", line 557, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "venv/lib/python3.6/site-packages/torch/utils/data/_utils/", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "venv/lib/python3.6/site-packages/torch/utils/data/_utils/", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "venv/lib/python3.6/site-packages/datasets/", line 1087, in __getitem__
  File "venv/lib/python3.6/site-packages/datasets/", line 1074, in _getitem
  File "venv/lib/python3.6/site-packages/datasets/", line 890, in _convert_outputs
    v = map_nested(command, v, **map_nested_kwargs)
  File "venv/lib/python3.6/site-packages/datasets/utils/", line 225, in map_nested
    return function(data_struct)
  File "venv/lib/python3.6/site-packages/datasets/", line 851, in command
    return torch.tensor(x, **format_kwargs)
TypeError: new(): invalid data type 'str'

Any ideas on how to make it work?

Thanks!

I think you also need to specify which columns you’d like to keep when calling .set_format(type='torch'). If you don’t, the text columns (such as article and highlights) are still included in the formatted output, and converting strings to PyTorch tensors raises the TypeError above.

So I think you just need to update that line to:

train_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
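The exact column list depends on the tokenizer and the preprocessing: not every tokenizer returns token_type_ids, for instance, and the map call in the question does not create a label column, while set_format expects the listed columns to exist. One option is to derive the list from a sample row instead of hard-coding it. A minimal sketch; `tensor_columns` is a hypothetical helper written for illustration, not part of the datasets library, and it simply keeps columns whose values are numbers or (nested) lists of numbers:

```python
def tensor_columns(example):
    """Return the names of columns whose values can be turned into tensors.

    `example` is one row as a plain dict, e.g. train_dataset[0] read before
    any torch format is set. Text values such as the raw 'article' and
    'highlights' strings are excluded.
    """
    def is_numeric(value):
        if isinstance(value, (bool, int, float)):
            return True
        if isinstance(value, (list, tuple)):
            # A list counts as numeric only if every element is numeric.
            return all(is_numeric(v) for v in value)
        # Strings and any other types are treated as non-tensor columns.
        return False

    return [name for name, value in example.items() if is_numeric(value)]
```

Usage would look like `train_dataset.set_format(type='torch', columns=tensor_columns(train_dataset[0]))`, called after the map step so the tokenized columns are present.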
