I am running the run_mlm_flax scripts on TPU v4. Will the per_device_batch_size parameter automatically scale when running on pods, or do the partitions need to be defined? And how? To clarify: let's say I am running a script with per_device_batch_size=100 on 4 TPUs (a TPU v4-32). How big will the actual batch size be?
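From what I can tell, the Flax example scripts derive the global batch size from the per-device value times the total number of JAX devices visible across all hosts. Here is a minimal sketch of how I understand that calculation (my own illustration, not the exact script code), which is really what my question is about:

```python
import jax

# Value passed on the command line.
per_device_batch_size = 100

# jax.device_count() reports the total number of devices across
# every host/process in the pod slice, not just the local host.
num_devices = jax.device_count()

# My understanding: the effective (global) batch size per step.
train_batch_size = per_device_batch_size * num_devices

print(f"devices: {num_devices}, global batch size: {train_batch_size}")
```

Is this the right mental model when running on a pod slice, or does something else need to be configured?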