Using IterableDataset with Trainer - 'IterableDataset' has no len()

Hi everyone,

I have a large-ish dataset that I am loading with something like:

from datasets import load_dataset

dataset_train = load_dataset(
    'json',
    data_files=...,
    split='train',
    streaming=True,
)

def preprocess_function(examples):
    # tokenizer here is e.g. an AutoTokenizer loaded from the same checkpoint
    return tokenizer(examples["text"], examples["text_pair"], truncation=True)

tokenized_dataset = dataset_train.map(preprocess_function, batched=True, batch_size=32)

and I want to fine-tune a text classification model with:

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name, num_labels=2).to(device)

training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    weight_decay=0.01,
    max_steps=int(1e6),
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

trainer.train()

However, I get:

TypeError: object of type 'IterableDataset' has no len()

I've tried looking around, but it's not quite clear what I am doing wrong.

I am using:

transformers==4.11.3
datasets==2.0.0

Any help would be appreciated!

The full stack trace is:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-e7894bd2c657> in <module>
      9 )
     10 
---> 11 trainer.train()

~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1288             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
   1289 
-> 1290             for step, inputs in enumerate(epoch_iterator):
   1291 
   1292                 # Skip past any already trained steps if resuming training

~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    519             if self._sampler_iter is None:
    520                 self._reset()
--> 521             data = self._next_data()
    522             self._num_yielded += 1
    523             if self._dataset_kind == _DatasetKind.Iterable and \

~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    558 
    559     def _next_data(self):
--> 560         index = self._next_index()  # may raise StopIteration
    561         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    562         if self._pin_memory:

~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _next_index(self)
    510 
    511     def _next_index(self):
--> 512         return next(self._sampler_iter)  # may raise StopIteration
    513 
    514     def _next_data(self):

~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/sampler.py in __iter__(self)
    224     def __iter__(self) -> Iterator[List[int]]:
    225         batch = []
--> 226         for idx in self.sampler:
    227             batch.append(idx)
    228             if len(batch) == self.batch_size:

~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/sampler.py in __iter__(self)
     64 
     65     def __iter__(self) -> Iterator[int]:
---> 66         return iter(range(len(self.data_source)))
     67 
     68     def __len__(self) -> int:

TypeError: object of type 'IterableDataset' has no len()

Hi Eric - you need to format your dataset for PyTorch first, like so: torch_iterable_dataset = dataset.with_format("torch"). That way the DataLoader iterates over the dataset directly instead of trying to draw indices from a length-based sampler, which is what fails in your traceback.
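
Applied to your snippet, the fix would look something like this (a minimal sketch, reusing the names from your code):

# with_format("torch") wraps the streaming dataset so it behaves as a
# torch.utils.data.IterableDataset, which the DataLoader iterates directly
# instead of indexing into via a sampler
torch_iterable_dataset = tokenized_dataset.with_format("torch")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=torch_iterable_dataset,
)

trainer.train()

Note that because an IterableDataset has no length, the Trainer cannot infer the number of steps per epoch, so setting max_steps in your TrainingArguments (as you already do) is required.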

More info on how and why can be found here: Error iteration over IterableDataset using Torch DataLoader · Issue #2583 · huggingface/datasets · GitHub

Hope that helps!

Cheers
Heiko


Thanks a lot @marshmellow77, it's working now! I had seen that issue, but I guess I didn't read it in enough detail.