Hi,
I’m trying to load the cnn-dailymail dataset to train a summarization model with PyTorch Lightning. To load the dataset with DataLoader I tried to follow the documentation, but it doesn't work. (The PyTorch Lightning code I am using does work when the DataLoader isn't using a dataset from Hugging Face, so there shouldn't be a problem in the training procedure.)
Here is the code:
def train_dataloader(self):
    train_dataset = load_dataset('cnn_dailymail', '3.0.0', split='train')
    train_dataset = train_dataset.map(lambda e: tokenizer(e['article'], e['highlights'], truncation=True, padding='max_length'), batched=True)
    train_dataset.set_format(type='torch')
    dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size)
    return dataloader
I passed 'article' and 'highlights' to the tokenizer because these are the two columns in the dataset that correspond to the source and the target, respectively.
Here is the stack trace:
Traceback (most recent call last):
  File "MBART.py", line 346, in <module>
    trainer.fit(model)
  File "venv/lib/python3.6/site-packages/pytorch_lightning/trainer/states.py", line 48, in wrapped_fn
    result = fn(self, *args, **kwargs)
  File "venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1073, in fit
    results = self.accelerator_backend.train(model)
  File "venv/lib/python3.6/site-packages/pytorch_lightning/accelerators/gpu_backend.py", line 51, in train
    results = self.trainer.run_pretrain_routine(model)
  File "venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1224, in run_pretrain_routine
    self._run_sanity_check(ref_model, model)
  File "venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1257, in _run_sanity_check
    eval_results = self._evaluate(model, self.val_dataloaders, max_batches, False)
  File "venv/lib/python3.6/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 305, in _evaluate
    for batch_idx, batch in enumerate(dataloader):
  File "venv/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
    data = self._next_data()
  File "venv/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 557, in _next_data
    data = self._dataset_fetcher.fetch(index) # may raise StopIteration
  File "venv/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "venv/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "venv/lib/python3.6/site-packages/datasets/arrow_dataset.py", line 1087, in __getitem__
    format_kwargs=self._format_kwargs,
  File "venv/lib/python3.6/site-packages/datasets/arrow_dataset.py", line 1074, in _getitem
    format_kwargs=format_kwargs,
  File "venv/lib/python3.6/site-packages/datasets/arrow_dataset.py", line 890, in _convert_outputs
    v = map_nested(command, v, **map_nested_kwargs)
  File "venv/lib/python3.6/site-packages/datasets/utils/py_utils.py", line 225, in map_nested
    return function(data_struct)
  File "venv/lib/python3.6/site-packages/datasets/arrow_dataset.py", line 851, in command
    return torch.tensor(x, **format_kwargs)
TypeError: new(): invalid data type 'str'
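Reading the last frames, my guess (an assumption, not something I have confirmed) is that set_format(type='torch') without a columns argument tries to convert every column, including the raw string columns that map keeps next to the tokenizer outputs. Below is a minimal sketch of how the numeric columns could be picked out. Note that tensor_columns and is_numeric are hypothetical helpers of my own, and the example row is made up rather than taken from the dataset:

```python
from numbers import Number


def is_numeric(value):
    """Return True if value is a number or a (possibly nested) list of numbers."""
    if isinstance(value, Number):
        return True
    if isinstance(value, (list, tuple)):
        return all(is_numeric(v) for v in value)
    return False


def tensor_columns(example):
    """Hypothetical helper: names of the columns that hold only numeric data."""
    return [name for name, value in example.items() if is_numeric(value)]


# Made-up row shaped like a tokenized example; the values are illustrative only.
row = {
    "article": "some article text",
    "highlights": "some summary",
    "id": "abc123",
    "input_ids": [0, 11, 12, 2],
    "attention_mask": [1, 1, 1, 1],
}
print(tensor_columns(row))
```

If that guess is right, passing the resulting list through the columns argument, e.g. train_dataset.set_format(type='torch', columns=tensor_columns(train_dataset[0])), might keep the string columns out of the tensor conversion, though I have not verified this on the full dataset.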
Any ideas on how to make it work?
Thanks!