FLAX - Training on Cloud TPU VM Pods (not single TPU devices)


I have been able to successfully fine-tune several summarization models on single TPU VM devices, such as v2-8 or v3-8, using the Flax scripts available here.

I would now like to scale training to larger TPU VM pods, such as v3-32.
Reading the Cloud TPU VMs documentation, it seems that the Python code needs to be copied to every worker VM and launched locally on each one.

It’s not clear to me whether there should be a “master” agent handling synchronization between the workers, plus “local” device agents (each one training on a slice of the dataset with its own local TPU devices), or whether the framework takes care of this.
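For context, here is a minimal sketch of my current understanding of JAX’s SPMD model on pods (please correct me if this is wrong; the function name `mean_across_devices` is just mine for illustration). As I understand it, the same script runs on every host, each host only drives its 8 local cores, and collectives inside `pmap` synchronize across all hosts without any explicit master process:

```python
import jax
import jax.numpy as jnp
from functools import partial

# On a v3-32 pod this same script would run on each of the 4 worker VMs.
# Each process sees only its local cores via jax.local_device_count(),
# while jax.device_count() reports all 32 cores in the pod.
print(jax.process_index(), jax.local_device_count(), jax.device_count())

@partial(jax.pmap, axis_name="batch")
def mean_across_devices(x):
    # pmean averages across ALL devices on ALL hosts -- this is where
    # gradient synchronization would happen, with no "master" agent.
    return jax.lax.pmean(x, axis_name="batch")

n = jax.local_device_count()
out = mean_across_devices(jnp.arange(float(n)))
print(out)  # every local device ends up holding the same global mean
```

If that mental model is right, then presumably each worker just needs to load its own shard of the dataset, and the existing `pmap`-based training step would synchronize gradients pod-wide on its own.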

I was wondering whether the example Flax scripts provided in the Transformers library can work on v3-32 or larger TPU VM pods, and whether someone could share a quick tutorial or a few tips to make it work.

Thanks!