Accelerate stalls when using a large TensorDataset with multiple processes

I have a fairly large TensorDataset in memory (~2 GB), and whenever I try to use Accelerate with multiple processes it hangs right as iteration over the DataLoader begins (for batch in train_loader:). See below for a reproducible example: when the dataset is 5k rows it works OK, but when I bring it up to 100k it just hangs. It also works fine when num_processes=1. What's causing this, and what can I do to fix it? Thanks!

from accelerate import Accelerator, notebook_launcher
import torch

# works OK with 5000 rows, but hangs if I bring it up to 100000
latent_train_data = torch.randn(5000, 32, 32)
train_label_embeddings = torch.randn(5000, 128)

dataset = torch.utils.data.TensorDataset(latent_train_data, train_label_embeddings)


def training_loop(mixed_precision="fp16"):
    accelerator = Accelerator(mixed_precision=mixed_precision)
    batch_size = 256
    n_epoch = 2

    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=4, shuffle=True, pin_memory=True
    )

    print("dataloader prep")
    train_loader = accelerator.prepare(train_loader)

    for i in range(1, n_epoch + 1):
        print(f"epoch: {i}")

        # with num_processes=2 it never gets past this line
        for batch in train_loader:
            x, y = batch
            print("got a batch")
            print(x.shape)
            print(y.shape)


args = ("fp16",)
notebook_launcher(training_loop, args, num_processes=2)
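
For reference, the exact same function runs to completion when launched with a single process, so the data and the loop themselves seem fine:

# single-process launch: both epochs finish without stalling
notebook_launcher(training_loop, ("fp16",), num_processes=1)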