Why does a bigger batch size give lower throughput and GPU usage?

I'm training bert-base on a single node with 8×A100 GPUs, using the run_mlm.py script.

With batch size 256, throughput is 8000 samples/s and GPU utilization is 80%.
With batch size 384, throughput is 4000 samples/s and GPU utilization is 50%.

What could explain this? Has data IO become the bottleneck?
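
One way I'm thinking of testing the IO hypothesis is to time the dataloader on its own, with no forward/backward pass. A minimal sketch below, where `dataset` and `collate_fn` are placeholders for whatever run_mlm.py actually constructs:

```python
import time
from torch.utils.data import DataLoader

# Placeholders: substitute the tokenized dataset and data collator
# that run_mlm.py builds for the real run.
loader = DataLoader(dataset, batch_size=384, num_workers=4,
                    pin_memory=True, collate_fn=collate_fn)

n_samples = 0
start = time.perf_counter()
for i, batch in enumerate(loader):
    n_samples += batch["input_ids"].size(0)
    if i >= 50:  # a few dozen batches is enough for a rough rate
        break
elapsed = time.perf_counter() - start
print(f"data-only throughput: {n_samples / elapsed:.0f} samples/s")
```

If the data-only rate at batch size 384 is itself close to 4000 samples/s, the GPUs are starving on input; if it's much higher, the slowdown must come from somewhere else (e.g., memory pressure or a different kernel path at the larger batch).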