Using an IterableDataset for more than 1 epoch in Trainer

I have an IterableDataset. It is finite and has a __len__() method. However, I cannot use it for more than 1 epoch in Trainer, because after the first epoch the dataset has already been iterated through and no samples are left. With num_train_epochs > 1 in TrainingArguments, I get this error after one epoch: ValueError: Batch does not contain any data (None). At the end of all iterable data available before expected stop iteration.

Can someone help me?

Here is my code:

from typing import List, Dict, Tuple, Literal, Any
import torch 
import datasets

from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

class DataStreamer(torch.utils.data.IterableDataset):
    def __init__(self):
        # example data 
        self.ds = datasets.Dataset.from_dict(
            {
                "text1":["A", "B", "C", "D", "E"], 
                "label":[0, 0, 0, 1, 1]
                }
            ) 
        
        self.ds = self.ds.map(lambda x: tokenizer(x["text1"], truncation=True, padding="max_length", max_length=20, return_tensors="pt"), batched=True)

        self.epoch_finish = False
        self.sample_pointer = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        while not self.epoch_finish: 
            sample = self.ds.shuffle()[self.sample_pointer]
            self.sample_pointer += 1

            if self.sample_pointer == len(self.ds): # dataset exhausted
                self.epoch_finish = True
            yield sample

    def __len__(self):
        return len(self.ds)

training_args = TrainingArguments(
    output_dir='./results',          
    num_train_epochs=5,  # only works when 1 
    per_device_train_batch_size=1,
    report_to="none"
)

trainer = Trainer(
    model=model, 
    args=training_args,
    data_collator= DataCollatorWithPadding(tokenizer=tokenizer),
    train_dataset=DataStreamer()
)

trainer.train()

Maybe this?

Thanks @John6666.

with_format() won’t work in my case because it is a method of the datasets.IterableDataset class, while my DataStreamer class inherits from torch.utils.data.IterableDataset.

The solution that eventually worked was on the first page you cited (which I had visited before asking, but I focused on the solution there instead of the question). I just wrapped the DataStreamer() object in torchdata.datapipes.iter.IterableWrapper().

from torchdata.datapipes.iter import IterableWrapper
# requires torchdata<=0.9.0, since the datapipes API was removed in later releases

trainer = Trainer(
    model=model, 
    args=training_args,
    data_collator= DataCollatorWithPadding(tokenizer=tokenizer),
    # train_dataset=DataStreamer().with_format("torch")
    train_dataset=IterableWrapper(DataStreamer()) 
)

For some reason, torchdata.nodes.IterableWrapper in torchdata>=0.10 (which is supposed to replace torchdata.datapipes.iter.IterableWrapper from the <=0.9 releases) does not work.

Next, I will try using datasets.IterableDataset, initializing it with the from_generator() method.
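
For reference, here is a minimal sketch of what I have in mind, reusing the tokenizer, model, and collator from above; the function/variable names and the max_steps value are placeholders I made up, so treat it as untested:

import datasets

def sample_generator():
    # The generator function is called again each time the dataset is
    # iterated, so every epoch starts from the beginning.
    ds = datasets.Dataset.from_dict(
        {"text1": ["A", "B", "C", "D", "E"], "label": [0, 0, 0, 1, 1]}
    )
    ds = ds.map(
        lambda x: tokenizer(x["text1"], truncation=True, padding="max_length", max_length=20),
        batched=True,
    )
    for row in ds.shuffle():
        yield row

stream_ds = datasets.IterableDataset.from_generator(sample_generator)
stream_ds = stream_ds.with_format("torch")  # with_format() is available on this class

# This dataset has no __len__, so Trainer needs max_steps instead of num_train_epochs.
training_args = TrainingArguments(
    output_dir="./results",
    max_steps=25,  # e.g. 5 "epochs" of 5 samples at batch size 1
    per_device_train_batch_size=1,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    train_dataset=stream_ds,
)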

