I have a fairly large TensorDataset in memory (2 GB), and whenever I try to use Accelerate it hangs just as it starts loading the dataset (at `for batch in train_loader:`).
See below for a reproducible example: when the dataset has 5k rows it works fine, but when I bring it up to 50k rows it just hangs. It also works fine when num_processes=1. What's causing this, and what can I do to fix it? Thanks!
from accelerate import Accelerator
from accelerate import notebook_launcher
import torch
# Synthetic stand-in for the real training data.
# Reported behaviour: works OK with 5000 rows, hangs when raised to 100000.
# The row count lives in one constant so both tensors always agree on their
# first dimension (TensorDataset requires that).
N_ROWS = 5000
latent_train_data = torch.randn(N_ROWS, 32, 32)
train_label_embeddings = torch.randn(N_ROWS, 128)
dataset = torch.utils.data.TensorDataset(
    latent_train_data,
    train_label_embeddings,
)
def training_loop(mixed_precision="fp16"):
    """Iterate over the module-level ``dataset`` and print each batch's shapes.

    Creates an ``Accelerator``, wraps ``dataset`` in a shuffling
    ``DataLoader`` (batch_size=256, num_workers=4, pin_memory=True), passes
    the loader through ``accelerator.prepare`` and then walks every batch for
    two epochs. No model or optimizer is involved; the function only prints.

    Args:
        mixed_precision: Forwarded unchanged to
            ``Accelerator(mixed_precision=...)``.
    """
    accelerator = Accelerator(mixed_precision=mixed_precision)
    batch_size = 256
    n_epoch = 2

    # `dataset` is read from module scope; it is not built inside this function.
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=4,
        shuffle=True,
        pin_memory=True,
    )

    print("model prep")
    train_loader = accelerator.prepare(train_loader)

    for i in range(1, n_epoch + 1):
        print(f'epoch: {i}')
        # NOTE(review): the reported hang occurs on entering this loop when
        # the dataset is large and num_processes > 1; the cause is not evident
        # from this code alone.
        for batch in train_loader:
            x, y = batch
            print('loader arghhh')
            print(x.shape)
            print(y.shape)
# Start training_loop with num_processes=2. The one-element tuple supplies
# the function's single argument (mixed_precision).
args = ("fp16",)
notebook_launcher(training_loop, args=args, num_processes=2)