I’m trying to train a model on a large (600GB on disk) dataset of pre-computed embedding vectors.
Training is I/O-bound (loading embeddings from disk). I'm trying to speed up data loading by using multiple DataLoader workers, but I'm running into memory issues.
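For context, my loading setup is roughly the following (a simplified sketch; the path, column names, and batch size are placeholders, not my real values):

```python
import torch
from datasets import load_from_disk
from torch.utils.data import DataLoader

# Simplified sketch of the current map-style setup
ds = load_from_disk("/data/embeddings")                       # ~600GB Arrow dataset on disk
ds = ds.with_format("torch", columns=["embedding", "label"])  # formatted for PyTorch

loader = DataLoader(
    ds,
    batch_size=256,
    shuffle=True,
    num_workers=32,   # more workers = faster epochs, but memory climbs until OOM
    pin_memory=True,
)
```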
Some benchmarks:
Using 32 workers, I get about 3 hr/epoch, but hit an OOM about 45 minutes in
Using 0 workers, I get about 8 hr/epoch with no memory issues
Using any worker count between 1 and 31 slows the rate of memory growth, but doesn't fix the underlying issue
I believe this is caused by PyTorch's DataLoader workers each building up their own copy of dataset memory in-process (link), but I don't know how to stop that. Is there a fix for this, or a way to periodically restart the DataLoader workers to clear their memory?
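To be concrete about the second part of the question, the kind of workaround I have in mind is something like this (just a sketch, not something I've verified works; `step_fn`, the chunk count, and the batch size are placeholders): split each epoch into chunks and rebuild the DataLoader per chunk, so the worker processes are torn down and their accumulated memory released.

```python
import torch
from torch.utils.data import DataLoader, Subset

def train_one_epoch(dataset, step_fn, num_chunks=8, num_workers=32, batch_size=256):
    # Shuffle once per epoch, then iterate the epoch in chunks. Each chunk gets a
    # fresh DataLoader, so its workers exit once the chunk is exhausted
    # (persistent_workers=False), releasing whatever memory they accumulated.
    indices = torch.randperm(len(dataset))
    chunk_size = (len(indices) + num_chunks - 1) // num_chunks
    for c in range(num_chunks):
        chunk = indices[c * chunk_size:(c + 1) * chunk_size].tolist()
        loader = DataLoader(
            Subset(dataset, chunk),
            batch_size=batch_size,
            num_workers=num_workers,
            persistent_workers=False,
        )
        for batch in loader:
            step_fn(batch)  # forward/backward/optimizer step
```

Is something like this reasonable, or is there a built-in option I'm missing?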
Some notes:
Datasets are formatted for PyTorch before training
Changing to an iterable dataset slows the rate of memory accumulation for the same number of workers, but doesn't stop it (see the sketch after these notes)
This is using:
torch==2.2.1
transformers==4.51.3
datasets==2.19.1
accelerate==1.6.0
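The iterable-dataset variant mentioned in the notes is roughly this (again a simplified sketch with placeholder values):

```python
from datasets import load_from_disk
from torch.utils.data import DataLoader

ds = load_from_disk("/data/embeddings")              # same placeholder path as above
iterable_ds = ds.to_iterable_dataset(num_shards=32)  # num_shards >= num_workers
iterable_ds = iterable_ds.with_format("torch")

# No shuffle arg here: shard order is fixed, and each worker streams its own shards.
loader = DataLoader(iterable_ds, batch_size=256, num_workers=32)
```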