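I found the answer in the thread "Using IterableDataset with Trainer - `IterableDataset' has no len()".
The fix is to call with_format("torch") on the iterable dataset, like this:
train_data.with_format("torch")
Note that with_format returns a new dataset rather than modifying train_data in place, so pass the result to the trainer.
With that change, the trainer should work without throwing the len() error.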
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=multibert,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_data.with_format("torch"),
    eval_dataset=train_data.with_format("torch"),
)
trainer.train()
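
For context, here is a minimal sketch of where the len() error comes from: an object that only defines __iter__ has no length, so any caller that asks for len() gets a TypeError. This is plain Python and not the Trainer's actual code; StreamingDataset and has_length are hypothetical names made up for illustration.

```python
class StreamingDataset:
    """Hypothetical stand-in for a streaming dataset: iterable, but no __len__."""

    def __init__(self, n):
        self._n = n

    def __iter__(self):
        # Examples are produced lazily, so no total count is exposed up front.
        return ({"input_ids": [i]} for i in range(self._n))


def has_length(dataset):
    """Return True if len(dataset) works, False if it raises TypeError."""
    try:
        len(dataset)
        return True
    except TypeError:
        # Raised for objects that do not define __len__.
        return False


stream = StreamingDataset(3)
print(has_length(stream))      # False
print(has_length([1, 2, 3]))   # True
print(sum(1 for _ in stream))  # 3
```

Related to this: since a streaming dataset has no length, the Trainer cannot work out the number of training steps from the number of epochs, so max_steps should be set to a positive value in the training arguments.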