Training with IterableDataset is very slow when using a large number of workers

My problem is:

I have 512 JSONL files (1 GB each) stored in the Amazon cloud, and 32 GPUs. I read this large dataset as an IterableDataset, as recommended by Hugging Face, so that I can train on the fly:

dataset = load_dataset("json", data_files=data_files, storage_options=storage_options, streaming=True)
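where data_files and storage_options are built roughly like this (the bucket path and credentials are placeholders; I assume the files are read from S3 through s3fs):

```python
# Placeholder S3 prefix and s3fs credentials -- the real values differ.
data_files = {
    "train": [f"s3://my-bucket/corpus/part-{i:03d}.jsonl" for i in range(512)],
}
storage_options = {
    "key": "<aws_access_key_id>",
    "secret": "<aws_secret_access_key>",
}
```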

But during training I found that the speed is very slow, especially when I use many GPU workers. I found the reason is that the Hugging Face Trainer uses a DispatchDataloader to read the IterableDataset.
This is the same issue as huggingface/accelerate#158.

Is there a good solution for my problem?
One solution I can think of is to split the 512 JSONL files across the 32 GPU workers when building the dataset, so that each GPU worker only reads its own 16 JSONL files. It seems I would have to shard the dataset manually in my code according to the GPU worker id (rank), and not pass the dataloader to accelerator.prepare().
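For example, something like this rough, untested sketch (it reuses the placeholder data_files and storage_options from above, and assumes torch.distributed is already initialized by the launcher):

```python
import torch.distributed as dist
from datasets import load_dataset

# Same placeholder file list and s3fs credentials as above.
data_files = {"train": [f"s3://my-bucket/corpus/part-{i:03d}.jsonl" for i in range(512)]}
storage_options = {"key": "<aws_access_key_id>", "secret": "<aws_secret_access_key>"}

rank = dist.get_rank()              # GPU worker id, 0..31
world_size = dist.get_world_size()  # 32

# Each rank reads only a disjoint slice of the 512 files (512 / 32 = 16 files per worker).
my_files = data_files["train"][rank::world_size]

shard = load_dataset(
    "json",
    data_files={"train": my_files},
    storage_options=storage_options,
    streaming=True,
)

# Each process then builds its own DataLoader from `shard`
# and does not pass it to accelerator.prepare(), so no batch dispatching happens.
```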

Does the Hugging Face Trainer currently support this way of data loading, or is there another way to deal with my problem?