TypeError using Accelerate with PyTorch Geometric

Hello everyone, I’m working on training a graph neural network with PyTorch Geometric on a dataset with hundreds of millions of rows. I’m attempting to train on a single machine that has three Nvidia RTX 3090s. The error does not occur when I use a single GPU. The traceback is:

 File "/home/username/projects/venv/lib/python3.11/site-packages/accelerate/data_loader.py", line 639, in __iter__
    next_batch, next_batch_info = self._fetch_batches(main_iterator)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/username/projects/venv/lib/python3.11/site-packages/accelerate/data_loader.py", line 602, in _fetch_batches
    batch = concatenate(batches, dim=0)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/username/projects/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 530, in concatenate
    return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/username/projects/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 530, in <dictcomp>
    return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/username/projects/venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 532, in concatenate
    raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
TypeError: Can only concatenate tensors but got <class 'torch_geometric.data.batch.HeteroDataBatch'>
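For anyone puzzling over the traceback, here is a simplified, torch-free sketch (my own stand-in, not accelerate's actual code: a `list` check substitutes for the `torch.Tensor` check, and list concatenation substitutes for `torch.cat`) of the recursive `concatenate` in `accelerate/utils/operations.py` that the traceback goes through. It recurses into dicts, but a leaf that is not a tensor falls through to the `TypeError` above — and a PyG `HeteroDataBatch` is exactly such a leaf:

```python
def concatenate(data, dim=0):
    # Recurse into dict-like containers, key by key.
    if isinstance(data[0], dict):
        return type(data[0])(
            {k: concatenate([d[k] for d in data], dim=dim) for k in data[0]}
        )
    # accelerate checks isinstance(data[0], torch.Tensor) here;
    # a list stands in for a tensor in this sketch.
    if not isinstance(data[0], list):
        raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
    # Stand-in for torch.cat(data, dim=dim).
    return sum(data, [])


class HeteroDataBatch:
    """Stand-in for torch_geometric.data.batch.HeteroDataBatch."""


# Dicts of "tensors" concatenate fine:
concatenate([{"x": [1, 2]}, {"x": [3]}])  # -> {"x": [1, 2, 3]}

# But a HeteroDataBatch leaf raises the TypeError from the traceback:
# concatenate([HeteroDataBatch(), HeteroDataBatch()])
```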

Is there any way to use Accelerate with this type of data? I found this related issue on GitHub, but it seems to have worked for them: Torch Geometric compatibility · Issue #51 · huggingface/accelerate · GitHub

Hi @irow, can you provide a reproducer? Currently, `DataLoaderDispatcher` (one GPU process fetches the batches and broadcasts them to the others) needs to concatenate the batches to send them to the other processes, and that only works with tensors for now. I suggest setting `dispatch_batches` to `False`. That way, each GPU will process its own data.
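A minimal sketch of that configuration change (assuming the standard `Accelerator` entry point; where exactly the flag goes depends on your accelerate version):

```python
from accelerate import Accelerator

# Each process iterates its own sharded DataLoader instead of the
# main process broadcasting concatenated batches, so non-tensor
# batch objects like HeteroDataBatch are never concatenated.
accelerator = Accelerator(dispatch_batches=False)

# On recent accelerate releases the flag has moved into
# DataLoaderConfiguration; the equivalent there is:
#
# from accelerate.utils import DataLoaderConfiguration
# accelerator = Accelerator(
#     dataloader_config=DataLoaderConfiguration(dispatch_batches=False)
# )
```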