Hi everyone,
I have a large-ish dataset that I am loading with something like:
dataset_train = load_dataset(
'json',
data_files=...,
split='train',
streaming=True,
)
def preprocess_function(examples):
return tokenizer(examples["text"], examples["text_pair"], truncation=True)
tokenized_dataset = dataset_train.map(preprocess_function, batched=True, batch_size=32)
and I want to fine-tune a text classification model with:
model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name, num_labels=2).to(device)
training_args = TrainingArguments(
output_dir="./results",
learning_rate=2e-5,
per_device_train_batch_size=32,
weight_decay=0.01,
max_steps=int(1e6),
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
)
trainer.train()
However I get
TypeError: object of type 'IterableDataset' has no len()
I’ve tried looking around but it’s not quite clear what I am doing wrong.
I am using
transformers==4.11.3
datasets==2.0.0
Any help would be appreciated.
The full stack trace is:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-12-e7894bd2c657> in <module>
9 )
10
---> 11 trainer.train()
~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1288 self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
1289
-> 1290 for step, inputs in enumerate(epoch_iterator):
1291
1292 # Skip past any already trained steps if resuming training
~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
519 if self._sampler_iter is None:
520 self._reset()
--> 521 data = self._next_data()
522 self._num_yielded += 1
523 if self._dataset_kind == _DatasetKind.Iterable and \
~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _next_data(self)
558
559 def _next_data(self):
--> 560 index = self._next_index() # may raise StopIteration
561 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
562 if self._pin_memory:
~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _next_index(self)
510
511 def _next_index(self):
--> 512 return next(self._sampler_iter) # may raise StopIteration
513
514 def _next_data(self):
~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/sampler.py in __iter__(self)
224 def __iter__(self) -> Iterator[List[int]]:
225 batch = []
--> 226 for idx in self.sampler:
227 batch.append(idx)
228 if len(batch) == self.batch_size:
~/.pyenv/versions/3.6.13/lib/python3.6/site-packages/torch/utils/data/sampler.py in __iter__(self)
64
65 def __iter__(self) -> Iterator[int]:
---> 66 return iter(range(len(self.data_source)))
67
68 def __len__(self) -> int:
TypeError: object of type 'IterableDataset' has no len()