Problem training with an iterable dataset

I am using PyTorch DDP (Distributed Data Parallel) to train my model. Since the data is too large to load into memory at once, I use load_dataset with streaming to read it as an iterable dataset, and datasets.distributed.split_dataset_by_node to split it across processes. However, I have noticed that this split gives different processes different amounts of data to train on. As a result, the process that finishes its training data first moves on to predicting on the test set while the other processes are still training, which makes the overall training very slow.
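
To show what I mean, here is a rough sketch that counts how many examples each rank gets (it assumes the same data_files and world_size as in the code below):

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# When n_shards is divisible by world_size, split_dataset_by_node assigns whole
# parquet shards (files) to each rank, so ranks end up with different example
# counts whenever the files differ in size.
ds = load_dataset("parquet", data_files=data_files, split="train", streaming=True)
for rank in range(world_size):
    shard = split_dataset_by_node(ds, rank=rank, world_size=world_size)
    print(rank, shard.n_shards, sum(1 for _ in shard))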

Here’s the code.

import datetime
import time

import datasets.distributed
import torch
from datasets import load_dataset


def train(args, model, device, train_loader, optimizer, criterion, epoch, length):
    model.train()
    idx_length = 0
    for batch_idx, data in enumerate(train_loader):
        s_time = time.time()
        X = data['X']
        target = data['y'].reshape(-1, 28)
        X, target = X.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(X)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        idx_length += 1
        if batch_idx % args.log_interval == 0:
            # print('Train Epoch: {} Batch_idx: {} Process: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            #     epoch, batch_idx, torch.distributed.get_rank(),
            #     batch_idx * len(X), length / torch.distributed.get_world_size(),
            #     100. * batch_idx * len(X) * torch.distributed.get_world_size() / length, loss.item()))
            print('Train Epoch: {} Batch_idx: {} Process: {} [{}/{} ({:.0f}%)]\t'.format(
                epoch, batch_idx, torch.distributed.get_rank(),
                batch_idx * len(X), length / torch.distributed.get_world_size(),
                100. * batch_idx * len(X) * torch.distributed.get_world_size() / length))
            if args.dry_run:
                break
    print('Process %s length: %s time: %s' % (torch.distributed.get_rank(), idx_length, datetime.datetime.now()))

train_iterable_dataset = load_dataset("parquet", data_files=data_files, split="train", streaming=True)
test_iterable_dataset = load_dataset("parquet", data_files=data_files, split="test", streaming=True)
train_iterable_dataset = train_iterable_dataset.map(process_fn)
test_iterable_dataset = test_iterable_dataset.map(process_fn)
train_iterable_dataset = train_iterable_dataset.map(scale)
test_iterable_dataset = test_iterable_dataset.map(scale)

train_iterable_dataset = datasets.distributed.split_dataset_by_node(
    train_iterable_dataset, world_size=world_size, rank=local_rank).shuffle(seed=1234)
test_iterable_dataset = datasets.distributed.split_dataset_by_node(
    test_iterable_dataset, world_size=world_size, rank=local_rank).shuffle(seed=1234)
print(torch.distributed.get_rank(), train_iterable_dataset.n_shards, test_iterable_dataset.n_shards)

train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
if use_cuda:
    cuda_kwargs = {'num_workers': 3,  # ngpus_per_node
                   'pin_memory': True,
                   'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

train_loader = torch.utils.data.DataLoader(train_iterable_dataset, **train_kwargs,
                                           # sampler=torch.utils.data.distributed.DistributedSampler(
                                           #     train_iterable_dataset,
                                           #     num_replicas=ngpus_per_node,
                                           #     rank=0)
                                           )
test_loader = torch.utils.data.DataLoader(test_iterable_dataset, **test_kwargs,
                                          # sampler=torch.utils.data.distributed.DistributedSampler(
                                          #     test_iterable_dataset,
                                          #     num_replicas=ngpus_per_node,
                                          #     rank=0)
                                          )

for epoch in range(1, args.epochs + 1):
    start_time = time.time()
    train_iterable_dataset.set_epoch(epoch)
    test_iterable_dataset.set_epoch(epoch)
    train(args, model, device, train_loader, optimizer, criterion, epoch, train_len)
    test(args, model, device, criterion2, test_loader)

And here’s part of the output:

Train Epoch: 1 Batch_idx: 5000 Process: 0 [320000/4710975.0 (7%)]	
Train Epoch: 1 Batch_idx: 5000 Process: 1 [320000/4710975.0 (7%)]	
Train Epoch: 1 Batch_idx: 5000 Process: 2 [320000/4710975.0 (7%)]	
Train Epoch: 1 Batch_idx: 5862 Process: 3 Data_length: 12 coststime: 0.04095172882080078
Train Epoch: 1 Batch_idx: 5862 Process: 0 Data_length: 3 coststime: 0.0751960277557373
Train Epoch: 1 Batch_idx: 5867 Process: 3 Data_length: 49 coststime: 0.0032558441162109375
Train Epoch: 1 Batch_idx: 5872 Process: 1 Data_length: 2 coststime: 0.022842884063720703
Train Epoch: 1 Batch_idx: 5876 Process: 3 Data_length: 63 coststime: 0.002694845199584961
Process 3 length: 5877 time: 2023-11-17 17:03:26.582317
Train epoch 1 costTime: 241.72063446044922s . Process 3 Start to test.
3 0 tensor(45508.8516, device='cuda:3')
3 100 tensor(45309.0469, device='cuda:3')
3 200 tensor(45675.3047, device='cuda:3')
3 300 tensor(45263.0273, device='cuda:3')
Process 3 Reduce metrics.
Train Epoch: 2 Batch_idx: 0 Process: 3 [0/4710975.0 (0%)]	
Train Epoch: 1 Batch_idx: 5882 Process: 1 Data_length: 63 coststime: 0.05185818672180176
Train Epoch: 1 Batch_idx: 5887 Process: 1 Data_length: 12 coststime: 0.006895303726196289
Process 1 length: 5888 time: 2023-11-17 17:20:48.578204
Train epoch 1 costTime: 1285.7279663085938s . Process 1 Start to test.
1 0 tensor(45265.9141, device='cuda:1')

I have tried adding barrier(), but then the code hangs until it times out.
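
For illustration, a barrier placed between train() and test() would look like the sketch below and deadlocks for the reason in the comment (this is a sketch of the idea, not my exact code):

for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, criterion, epoch, train_len)
    # The fastest rank reaches this barrier first and waits here. The slower ranks
    # are still inside loss.backward(), whose gradient all-reduce needs every rank
    # to participate -- including the one stuck at the barrier -- so both sides
    # block until the collective-communication timeout fires.
    torch.distributed.barrier()
    test(args, model, device, criterion2, test_loader)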

In my case, IterableDatasetShard (from transformers.trainer_pt_utils) helped.
The code looks roughly like this (with world_size / rank taken from your DDP setup):

from torch.utils.data import DataLoader
from transformers.trainer_pt_utils import IterableDatasetShard

train_dataset = IterableDatasetShard(train_iterable_dataset, batch_size=args.batch_size,
                                     num_processes=world_size, process_index=rank)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size)
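
This works because IterableDatasetShard makes every process iterate over the full stream and keep only its own batch_size-sized slice out of each block of num_processes * batch_size examples, and with drop_last=False it loops back to the beginning to complete the last block, so all ranks yield the same number of batches and no rank finishes early.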