Running out of memory processing dataset

see: Stream (the Datasets docs page on streaming)
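
For example, a minimal streaming setup might look like this (the dataset name is a placeholder):

from datasets import load_dataset

# streaming=True yields an IterableDataset: examples are read lazily as you
# iterate, so the full dataset never has to fit in memory
ds = load_dataset("my_dataset", split="train", streaming=True)  # placeholder name

# shuffling a streaming dataset uses a fixed-size buffer instead of a
# full in-memory shuffle
ds = ds.shuffle(seed=42, buffer_size=1000)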

I had a similar issue where preprocessing the dataset would just fill up memory and OOM. I solved it by using an IterableDataset, but I got the feeling that wasn't the intended approach. My impression is that preprocessing (specifically the .map() function) is meant for small, non-memory-intensive operations like tokenization, not for loading large data like images into memory. I remember doing something hacky like keeping only the image paths in the dataset and using a transform/augmentation step to load each image on the fly. All of this is just my limited experience, though; .map() supposedly shouldn't OOM, since it doesn't load the whole dataset at once, and I got frustrated at some point and gave up on figuring out the right way.
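
For what it's worth, here is a rough sketch of that path-only trick using Datasets' set_transform(); the file paths and column names are hypothetical:

from datasets import Dataset
from PIL import Image

# keep only lightweight string paths in the dataset itself
ds = Dataset.from_dict({"path": ["img/0001.jpg", "img/0002.jpg"]})  # hypothetical paths

def load_images(batch):
    # runs on the fly each time rows are accessed, so decoded images
    # are never materialized in the dataset's cache
    batch["image"] = [Image.open(p).convert("RGB") for p in batch["path"]]
    return batch

ds.set_transform(load_images)  # unlike .map(), nothing is precomputed up front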

You can shuffle an iterable/streaming dataset (see the link above), and you can also use a Trainer callback to trigger a reshuffle at the start of each epoch. I haven't tested the code below, but something like this should work:

from datasets import IterableDataset
from transformers import TrainerCallback

class ShuffleCallback(TrainerCallback):
    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        # bump the streaming dataset's epoch so its shuffle buffer reshuffles
        if isinstance(train_dataloader.dataset, IterableDataset):
            train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)

and

trainer_object = Trainer(
    ...
    callbacks=[ShuffleCallback()],
)
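
(If I understand it correctly, set_epoch() works because a streaming dataset's shuffle buffer is seeded with seed + epoch, so bumping the epoch gives a different order on each pass.)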