I have a finite IterableDataset that implements __len__(). However, I cannot use it for more than one epoch in Trainer, because after the first epoch the dataset has been fully iterated and no samples are left. With num_train_epochs > 1 in TrainingArguments, the second epoch fails with ValueError: Batch does not contain any data (None). At the end of all iterable data available before expected stop iteration.
Can someone help me? (A minimal check that reproduces the exhaustion outside Trainer is included after my code below.)
Here is my code:
import torch
import datasets
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
class DataStreamer(torch.utils.data.IterableDataset):
    def __init__(self):
        # example data
        self.ds = datasets.Dataset.from_dict(
            {
                "text1": ["A", "B", "C", "D", "E"],
                "label": [0, 0, 0, 1, 1],
            }
        )
        self.ds = self.ds.map(
            lambda x: tokenizer(x["text1"], truncation=True, padding="max_length", max_length=20, return_tensors="pt"),
            batched=True,
        )
        self.epoch_finish = False
        self.sample_pointer = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        # walk through the (reshuffled) dataset one sample at a time
        while not self.epoch_finish:
            sample = self.ds.shuffle()[self.sample_pointer]
            self.sample_pointer += 1
            if self.sample_pointer == len(self.ds):  # dataset exhausted
                self.epoch_finish = True
            yield sample

    def __len__(self):
        return len(self.ds)
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=5,  # only works when 1
    per_device_train_batch_size=1,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    train_dataset=DataStreamer(),
)

trainer.train()
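To show what I mean, here is a minimal check of the dataset outside Trainer (a sketch, assuming the one-shot iterator state is the cause; streamer is just a fresh instance of my class above):

streamer = DataStreamer()
print(len(list(iter(streamer))))  # 5 -- the first pass yields every sample
print(len(list(iter(streamer))))  # 0 -- the second pass yields nothing: epoch_finish is still True and sample_pointer was never reset

This looks consistent with the Trainer error: at the start of the second epoch the iterator stops immediately, so the batch comes back as None.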