FLAX - Training on Cloud TPU VM Pods (not single TPU devices)


I have been able to successfully finetune several summarization models on TPU VM single devices, such as v2-8 or v3-8, using FLAX scripts available here.

I would like now to scale training to larger TPU VM pods, such as v3-32.
From reading the Cloud TPU VMs documentation, it seems that the Python code needs to be copied to every worker VM and launched locally.

It’s not clear to me whether there should be a “master” agent that synchronizes the workers, plus “local” device agents (each of them training on a slice of the dataset with its own local TPU device).

I was wondering whether the example Flax scripts provided in the Transformers library work with v3-32 or larger TPU VM pods, and whether someone could share a quick tutorial or a few tips to make it work.

Thanks!

Hey @bduclaux,

You’re entirely right: you’ll need to copy all the Python code to every TPU worker. Because JAX runs synchronously across TPU workers, you’ll also need to launch the same command on each worker.
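Since the identical script runs on every worker, each process usually decides which slice of the data to load from its own process index. A minimal sketch, assuming a hypothetical helper `host_shard` (not part of the example scripts) and a dataset that divides evenly across hosts; depending on your JAX version you may also need to call `jax.distributed.initialize()` at the top of the script on a pod:

```python
import jax
import numpy as np

def host_shard(dataset_size, process_index, process_count):
    """Hypothetical helper: indices of the examples this host should load."""
    per_host = dataset_size // process_count  # assumes an even split
    start = process_index * per_host
    return np.arange(start, start + per_host)

# On a v3-32 pod there are 4 host processes (process_count() == 4),
# each driving 8 local TPU cores (local_device_count() == 8).
idx = host_shard(1024, jax.process_index(), jax.process_count())
print(jax.process_index(), jax.process_count(), jax.local_device_count(), len(idx))
```

You would then launch the same command on every worker at once, for example with `gcloud compute tpus tpu-vm ssh` and its `--worker=all` flag.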

We recently looked into this ourselves with the BLOOM 🌸 project! Although we were only running inference, the same principles apply to training on a pod. Our set-up was:

  • CPU host: SSH into the CPU host and work from there. The CPU host dispatches commands to the TPU workers and receives the TPU outputs.
  • TPU worker: communicates with the CPU host. Receives commands from the CPU host and returns the outputs.

We found the most straightforward way to achieve this was with Ray. Essentially, Ray lets you define one Python class for the CPU host and another for the TPU workers. To run the scripts, you simply call the relevant methods on each.
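The host/worker split can be sketched without a cluster. This is a dependency-free stand-in, not Ray and not the repo's code: it uses Python's `concurrent.futures` thread pool in place of Ray actors, and the class and method names (`WorkerStub`, `HostController`, `broadcast`, `run_step`) are hypothetical. It only shows the shape of the pattern: the host fans every command out to all workers and gathers the results in order.

```python
from concurrent.futures import ThreadPoolExecutor

class WorkerStub:
    """Hypothetical stand-in for an actor living on one TPU worker VM."""
    def __init__(self, worker_id):
        self.worker_id = worker_id
        self.state = None

    def load(self, ckpt):
        # In practice: load weights onto this worker's local TPU devices.
        self.state = f"{ckpt}@worker{self.worker_id}"
        return self.worker_id

    def run_step(self, batch_id):
        # Every worker must run the same step, since JAX collectives wait for all hosts.
        return (self.worker_id, batch_id)

class HostController:
    """Hypothetical CPU-host controller: sends one command to all workers, collects outputs."""
    def __init__(self, num_workers):
        self.workers = [WorkerStub(i) for i in range(num_workers)]
        self.pool = ThreadPoolExecutor(max_workers=num_workers)

    def broadcast(self, method, *args):
        futures = [self.pool.submit(getattr(w, method), *args) for w in self.workers]
        return [f.result() for f in futures]  # results in worker order

host = HostController(num_workers=4)  # e.g. a v3-32 pod has 4 worker VMs
host.broadcast("load", "ckpt-0")
outputs = host.broadcast("run_step", 7)
host.pool.shutdown()
print(outputs)  # [(0, 7), (1, 7), (2, 7), (3, 7)]
```

With Ray, the worker class would instead be decorated with `@ray.remote`, instantiated with `WorkerStub.remote(i)`, its methods called as `w.run_step.remote(...)`, and the outputs collected with `ray.get(...)`.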

We’ve open-sourced the repo we developed here: GitHub - huggingface/bloom-jax-inference
This repo uses pjit to achieve tensor parallelism across TPU devices, in contrast to the pmap used in the example scripts, which provides only data parallelism. The advantage of pjit is that you don’t need to replicate the entire model on every TPU device. Instead, you “slice” up the weight matrices and place a portion on each device, so each device holds only a fraction of the model. This lets you train much larger models, since the weights are sharded across devices.
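To illustrate the “slicing” idea, here is a minimal sketch of sharding one weight matrix column-wise across devices. It assumes a recent JAX release, where pjit-style partitioning is expressed with `jax.jit` plus `NamedSharding`; the mesh axis name `"model"` and the matrix sizes are illustrative choices, not taken from the repo. It also runs on a single CPU device, where the “mesh” has size 1.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, axis_names=("model",))  # 1D mesh over all devices

d_in, d_out = 16, 8 * n  # d_out divisible by the number of devices
w = jnp.ones((d_in, d_out))
# Split the columns of w across the "model" axis: each device holds (d_in, 8).
w_sharded = jax.device_put(w, NamedSharding(mesh, P(None, "model")))

@jax.jit
def forward(x, w):
    # XLA keeps the output sharded along the same column axis as w.
    return x @ w

x = jnp.ones((4, d_in))
y = forward(x, w_sharded)
shard_shapes = [s.data.shape for s in w_sharded.addressable_shards]
print(n, shard_shapes[0], y.shape)
```

Each device stores only `d_out / n` columns, which is why the per-device memory cost of the weights shrinks as you add devices.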

We plan to document this repo and write a blog post explaining how it all works in the coming weeks 🙂
I’ve also developed a branch for use on a v3-32 (if of interest): GitHub - huggingface/bloom-jax-inference at v3-32

If you’re sticking with pmap, you’ll only need the classes in host_worker.py and tpu_manager.py to get going, and you can run everything end-to-end with the run.py script! If you’re using pjit, take a look at generator.py and the modelling code included in the repo (2D model, activation and data parallelism for a 176B-parameter model).
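For the pmap route, the main point is that on a pod each host passes in only its local batch (leading axis equal to `jax.local_device_count()`), while collectives such as `jax.lax.pmean` average over every device in the pod. A minimal data-parallel training step, assuming a toy linear-regression model and the axis name `"batch"` (both illustrative, not the summarization scripts or the repo’s code):

```python
import functools
import jax
import jax.numpy as jnp

n_local = jax.local_device_count()  # 8 per host on a v3 pod

def loss_fn(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def train_step(w, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(w, x, y)
    # Average gradients and loss across all devices (on all hosts in a pod).
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    return w - 0.1 * grads, loss

# Replicate the parameters: one copy per local device.
w_repl = jnp.stack([jnp.zeros((3,))] * n_local)
# Each device gets its own slice of this host's batch: shape (n_local, 8, 3).
x = jax.random.normal(jax.random.PRNGKey(0), (n_local, 8, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w

losses = []
for _ in range(100):
    w_repl, loss = train_step(w_repl, x, y)
    losses.append(float(loss[0]))
print(losses[0], losses[-1])
```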

As I mentioned, we’ll be releasing a blog post and docs over the next couple of weeks, which should provide a more in-depth tutorial. In the meantime, feel free to ping me here with any specific questions!