How to use the Huggingface Trainer with streaming Datasets without wrapping them in torchdata's IterableWrapper?

Given a datasets.iterable_dataset.IterableDataset loaded with streaming=True, e.g.

train_data = load_dataset("csv", data_files="../input/tatoeba/tatoeba-sentpairs.tsv", 
                  streaming=True, delimiter="\t", split="train")

and trying to use it in a Trainer object, e.g.

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=multibert,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=train_data,
)

trainer.train()

It throws an error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_27/3002801805.py in <module>
     28 )
     29 
---> 30 trainer.train()

/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1411             resume_from_checkpoint=resume_from_checkpoint,
   1412             trial=trial,
-> 1413             ignore_keys_for_eval=ignore_keys_for_eval,
   1414         )
   1415 

/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1623 
   1624             step = -1
-> 1625             for step, inputs in enumerate(epoch_iterator):
   1626 
   1627                 # Skip past any already trained steps if resuming training

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    528             if self._sampler_iter is None:
    529                 self._reset()
--> 530             data = self._next_data()
    531             self._num_yielded += 1
    532             if self._dataset_kind == _DatasetKind.Iterable and \

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    567 
    568     def _next_data(self):
--> 569         index = self._next_index()  # may raise StopIteration
    570         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    571         if self._pin_memory:

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_index(self)
    519 
    520     def _next_index(self):
--> 521         return next(self._sampler_iter)  # may raise StopIteration
    522 
    523     def _next_data(self):

/opt/conda/lib/python3.7/site-packages/torch/utils/data/sampler.py in __iter__(self)
    224     def __iter__(self) -> Iterator[List[int]]:
    225         batch = []
--> 226         for idx in self.sampler:
    227             batch.append(idx)
    228             if len(batch) == self.batch_size:

/opt/conda/lib/python3.7/site-packages/torch/utils/data/sampler.py in __iter__(self)
     64 
     65     def __iter__(self) -> Iterator[int]:
---> 66         return iter(range(len(self.data_source)))
     67 
     68     def __len__(self) -> int:

TypeError: object of type 'IterableDataset' has no len()

This can be resolved by wrapping the IterableDataset object with the IterableWrapper from torchdata library.

from torchdata.datapipes.iter import IterDataPipe, IterableWrapper

...


# instantiate trainer
trainer = Seq2SeqTrainer(
    model=multibert,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=IterableWrapper(train_data),
    eval_dataset=IterableWrapper(train_data),
)

trainer.train()

Is it possible to use the IterableDataset with Seq2SeqTrainer without casting it with IterableWrapper?


For reference, a full working example is shown below. Replacing the line train_dataset=IterableWrapper(train_data) with train_dataset=train_data reproduces the TypeError: object of type 'IterableDataset' has no len() error.

import torch

from datasets import load_dataset
from transformers import EncoderDecoderModel
from transformers import AutoTokenizer
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

from torchdata.datapipes.iter import IterDataPipe, IterableWrapper

# Build an encoder-decoder model whose encoder and decoder are both loaded
# from the same "bert-base-multilingual-uncased" checkpoint.
multibert = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-multilingual-uncased", "bert-base-multilingual-uncased"
)
tokenizer= AutoTokenizer.from_pretrained("bert-base-multilingual-uncased")
# reuse the tokenizer's cls/sep tokens as the bos/eos tokens
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# set special tokens on the model config so they match the tokenizer
multibert.config.decoder_start_token_id = tokenizer.bos_token_id
multibert.config.eos_token_id = tokenizer.eos_token_id
multibert.config.pad_token_id = tokenizer.pad_token_id

# copy the decoder's vocabulary size onto the top-level model config
multibert.config.vocab_size = multibert.config.decoder.vocab_size

def process_data_to_model_inputs(batch, max_len=10):
    """Tokenize a batch of sentence pairs into encoder/decoder model inputs.

    Reads the "SRC" and "TRG" columns of ``batch``, tokenizes each with the
    module-level ``tokenizer`` (padded and truncated to ``max_len`` tokens),
    and stores the results back on ``batch`` under the keys ``input_ids``,
    ``attention_mask``, ``decoder_input_ids``, ``decoder_attention_mask``
    and ``labels``.

    Args:
        batch: Mapping with "SRC" and "TRG" entries to tokenize.
        max_len: Length every sequence is padded/truncated to.

    Returns:
        The same ``batch`` object, with the five keys above added.
    """
    shared_kwargs = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_len,
    }
    source_enc = tokenizer(batch["SRC"], **shared_kwargs)
    target_enc = tokenizer(batch["TRG"], **shared_kwargs)

    # Labels are the target ids with every padding position replaced by
    # -100, so that padded positions are ignored (the decoder inputs keep
    # the real pad ids).
    pad_id = tokenizer.pad_token_id
    masked_labels = []
    for target_ids in target_enc.input_ids:
        masked_labels.append(
            [-100 if token == pad_id else token for token in target_ids]
        )

    batch["input_ids"] = source_enc.input_ids
    batch["attention_mask"] = source_enc.attention_mask
    batch["decoder_input_ids"] = target_enc.input_ids
    batch["decoder_attention_mask"] = target_enc.attention_mask
    batch["labels"] = masked_labels

    return batch


# tatoeba-sentpairs.tsv is a pretty large file, so it is opened in streaming mode.
train_data = load_dataset("csv", data_files="../input/tatoeba/tatoeba-sentpairs.tsv", 
                  streaming=True, delimiter="\t", split="train")

# Tokenize the stream in batches. `map` returns a new dataset, so rebind the
# name (the original called `ds.map`, but `ds` was never defined -> NameError).
train_data = train_data.map(process_data_to_model_inputs, batched=True)



# per-device batch size shared by training and evaluation below
batch_size = 1

# set training arguments - these params are not really tuned, feel free to change
# The small step counts below make this a quick smoke test; the trailing
# comments give the values suggested for a full training run.
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    logging_steps=2,  # set to 1000 for full training
    save_steps=16,    # set to 500 for full training
    eval_steps=4,     # set to 8000 for full training
    warmup_steps=1,   # set to 2000 for full training
    max_steps=16,     # delete for full training
    # overwrite_output_dir=True,
    save_total_limit=1,
    #fp16=True, 
)


# instantiate trainer
# The streaming dataset is wrapped in torchdata's IterableWrapper here; passing
# `train_data` unwrapped reproduces the
# "TypeError: object of type 'IterableDataset' has no len()" error.
# NOTE: the same stream is used for both training and evaluation in this example.
trainer = Seq2SeqTrainer(
    model=multibert,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=IterableWrapper(train_data),
    eval_dataset=IterableWrapper(train_data),
)

trainer.train()

P/S: Also asked on python - How to use Huggingface Trainer streaming Datasets without wrapping it with torchdata's IterableWrapper? - Stack Overflow

Found the answer from Using IterableDataset with Trainer - `IterableDataset' has no len()

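By calling with_format("torch") on the iterable dataset and passing the returned dataset to the trainer (the call returns a new dataset; it does not modify train_data in place), like this: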

train_data.with_format("torch")

The trainer should work without throwing the len() error.

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=multibert,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_data.with_format("torch"),
    eval_dataset=train_data.with_format("torch"),
)

trainer.train()
3 Likes