How to use a streaming Dataset with the Hugging Face Trainer without wrapping it in torchdata's IterableWrapper?

Found the answer in the thread "Using IterableDataset with Trainer - `IterableDataset` has no len()".

By calling `with_format("torch")` on the iterable dataset, like this:

train_data.with_format("torch")

the Trainer works without throwing the `len()` error.

# instantiate the trainer; with_format("torch") lets the Trainer consume
# the streaming dataset directly (here eval_dataset reuses train_data,
# so in practice you would pass a separate validation split)
trainer = Seq2SeqTrainer(
    model=multibert,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_data.with_format("torch"),
    eval_dataset=train_data.with_format("torch"),
)

trainer.train()