Using IterableDataset with Trainer - `IterableDataset` has no len()

Hi everyone,

I have a large-ish dataset that I am loading with something like:

dataset_train = load_dataset(
    'json',
    data_files=...,
    split='train',
    streaming=True,
)

def preprocess_function(examples):
    return tokenizer(examples["text"], examples["text_pair"], truncation=True)

tokenized_dataset = dataset_train.map(preprocess_function, batched=True, batch_size=32)

and I want to fine-tune a text classification model with:

model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name, num_labels=2).to(device)

training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    weight_decay=0.01,
    max_steps=int(1e6),
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

trainer.train()

However, I get:

TypeError: object of type 'IterableDataset' has no len()

I’ve tried looking around, but it’s not quite clear what I am doing wrong.

I am using:

transformers==4.11.3
datasets==2.0.0

Any help would be appreciated.

The full stack trace is:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-e7894bd2c657> in <module>
      9 )
     10 
---> 11 trainer.train()

~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1288             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
   1289 
-> 1290             for step, inputs in enumerate(epoch_iterator):
   1291 
   1292                 # Skip past any already trained steps if resuming training

~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    519             if self._sampler_iter is None:
    520                 self._reset()
--> 521             data = self._next_data()
    522             self._num_yielded += 1
    523             if self._dataset_kind == _DatasetKind.Iterable and \

~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    558 
    559     def _next_data(self):
--> 560         index = self._next_index()  # may raise StopIteration
    561         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    562         if self._pin_memory:

~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _next_index(self)
    510 
    511     def _next_index(self):
--> 512         return next(self._sampler_iter)  # may raise StopIteration
    513 
    514     def _next_data(self):

~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/sampler.py in __iter__(self)
    224     def __iter__(self) -> Iterator[List[int]]:
    225         batch = []
--> 226         for idx in self.sampler:
    227             batch.append(idx)
    228             if len(batch) == self.batch_size:

~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/sampler.py in __iter__(self)
     64 
     65     def __iter__(self) -> Iterator[int]:
---> 66         return iter(range(len(self.data_source)))
     67 
     68     def __len__(self) -> int:

TypeError: object of type 'IterableDataset' has no len()

Hi Eric - you need to format your dataset for PyTorch first, like so: torch_iterable_dataset = dataset.with_format("torch").
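
Applied to your snippet, that would look something like this (a sketch that reuses your preprocess_function; everything else stays the same):

from datasets import load_dataset

dataset_train = load_dataset(
    'json',
    data_files=...,  # same files as in your snippet
    split='train',
    streaming=True,
)

tokenized_dataset = dataset_train.map(preprocess_function, batched=True, batch_size=32)

# Format for PyTorch so the DataLoader created by the Trainer treats this
# as a torch IterableDataset (no sampler, and therefore no len() call)
tokenized_dataset = tokenized_dataset.with_format("torch")

Since the dataset has no length, max_steps does still need to be set in TrainingArguments, as you already do.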

More info on how and why can be found here: Error iteration over IterableDataset using Torch DataLoader · Issue #2583 · huggingface/datasets · GitHub

Hope that helps!

Cheers
Heiko

Thanks a lot @marshmellow77, it’s working now - I had seen that issue, but I guess I didn’t read it in enough detail!

Here is an example using @eloaf's snippet (with the imports it needs; note that it relies on a tokenizer being in scope):

import torch
from transformers import StoppingCriteria


class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=None, encounters=1):
        super().__init__()
        self.stops = stops if stops is not None else []  # avoid a mutable default argument
        self.ENCOUNTERS = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode the sequence so far and stop once any stop string has
        # appeared more than ENCOUNTERS times
        text = tokenizer.decode(input_ids[0])
        for stop in self.stops:
            if len(text.split(stop)) - 1 > self.ENCOUNTERS:
                return True
        return False
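
A minimal usage sketch (assuming a model, tokenizer, and input_ids are already set up; the stop string here is just a placeholder):

from transformers import StoppingCriteriaList

stopping_criteria = StoppingCriteriaList(
    [StoppingCriteriaSub(stops=["\n\n"], encounters=1)]
)

# generate() evaluates the criteria after each new token and halts early
outputs = model.generate(
    input_ids,
    max_new_tokens=100,
    stopping_criteria=stopping_criteria,
)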

This seems to solve the problem. Is it supposed to work for audio datasets in streaming mode as well?

dataset = load_dataset("polinaeterna/vox_lingua", "all", split="train", streaming=True)
dataset = dataset.with_format("torch")
assert isinstance(dataset, torch.utils.data.IterableDataset)  # it passes
print(f"VL107_HF: {len(dataset)}")  # TypeError: object of type 'IterableDataset' has no len()

This gives me the same error as in the original question.

That error is expected: you cannot call len() on an iterable dataset. You can, however, iterate over it in batches to train a model without having to download the full dataset to disk.
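
For example, something along these lines (a sketch using the vox_lingua dataset from above; the loop body is a placeholder):

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset

dataset = load_dataset("polinaeterna/vox_lingua", "all", split="train", streaming=True)
dataset = dataset.with_format("torch")

# An IterableDataset has no sampler, so the DataLoader never calls len().
# batch_size=1 sidesteps collation of variable-length audio arrays; larger
# batches would need a custom collate_fn.
dataloader = DataLoader(dataset, batch_size=1)

for batch in dataloader:
    # ... forward pass, loss, optimizer step ...
    break  # just inspecting the first batch here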

I am having the same issue, even with .with_format("torch"):

self.dataset = {
    dataset_key: load_dataset(
        self.hparams.dataset_name,
        cache_dir=self.hparams.data_dir,
        split=dataset_key,
        streaming=True,
    ).with_format("torch")
    for dataset_key in self.data_splits
}
assert isinstance(self.dataset['train'], torch.utils.data.IterableDataset)  # it passes
print(f"Train dataset length: {len(self.dataset['train'])}")
...
TypeError: object of type 'IterableDataset' has no len()