I am running the run_mlm_flax scripts on TPU v4. Will the per_device_batch_size parameter automatically scale when running on pods, or do the partitions need to be defined? And how? To clarify: let's say I am running the script with per_device_batch_size=100 on 4 TPU VMs (a TPU v4-32). How big will the actual batch size be? A sketch of how I understand the batch size computation is below.
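
This is only a minimal sketch of my understanding, assuming the Flax example scripts compute the global batch size by multiplying the per-device value by `jax.device_count()`; the variable names and the value 100 are just for illustration:

```python
import jax

# Illustrative value; the real script reads this from its CLI arguments.
per_device_batch_size = 100

# jax.device_count() reports the *global* number of devices across all pod hosts,
# while jax.local_device_count() reports only the devices attached to this host.
total_devices = jax.device_count()

# Assumed effective (global) batch size if the script scales by the global device count:
train_batch_size = per_device_batch_size * total_devices
print(f"devices: {total_devices}, global batch size: {train_batch_size}")
```

If that assumption holds, the effective batch size on a pod slice would grow with the total number of devices across all hosts, not just the ones local to a single VM.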