Flax/Jax/TPU questions

  • How can I get the currently available memory on a TPU, something like nvidia-smi? I tried to use bfloat16 on TPU, but it doesn't seem to work. Is there a way to get TPU memory info from another process? (TensorBoard didn't work, though I'm not sure if I did something wrong, and pprof looks like it can only profile the usage of the process itself.) I'm getting paranoid and starting to think there might be TPU memory held by a dead process.

I decided to use a GPU to debug and get an idea of how much memory is used. I tried Adafactor and float16, and the memory savings seem small. (I set XLA_PYTHON_CLIENT_ALLOCATOR=platform and XLA_PYTHON_CLIENT_PREALLOCATE=false so I can see the actual usage.) A sketch of the kind of per-device readout I'm after is below.
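
A minimal sketch, assuming the installed jax exposes Device.memory_stats() and jax.profiler.save_device_memory_profile (availability may depend on the jax/jaxlib version and backend):

import jax

# Print per-device memory counters, if the backend exposes them.
# memory_stats() returns a dict on backends that support it, otherwise None.
for device in jax.local_devices():
    stats = device.memory_stats()
    if stats:
        print(device, stats.get("bytes_in_use"), stats.get("peak_bytes_in_use"))
    else:
        print(device, "memory_stats() not available on this backend")

# Alternatively, dump a pprof-compatible device memory profile to inspect offline.
jax.profiler.save_device_memory_profile("memory.prof")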

  • How does float16/bfloat16 work? I printed the params and they're float32; is that expected behavior? Here is the code:
import jax.numpy as jnp
from transformers import FlaxT5ForConditionalGeneration, AutoConfig
import jax

config = AutoConfig.from_pretrained('t5-base')
model = FlaxT5ForConditionalGeneration(config, seed=0, dtype=jnp.float16)
print(jax.tree_map(lambda x: x.dtype, model.params))

output: (trimmed)

{'decoder': {'block': {'0': {'layer': {'0': {'SelfAttention': {'k': {'kernel': dtype('float32')},
       'o': {'kernel': dtype('float32')},
       'q': {'kernel': dtype('float32')},
       'relative_attention_bias': {'embedding': dtype('float32')},
  • Another float16/bfloat16 question: in the Google t5x (Flax T5) code, there is this line:
    if use_bfloat16:
      grad = jax.tree_map(lambda x: x.astype(jnp.bfloat16), grad)

But the Hugging Face example doesn't have this line. Which one is “correct”? Is there a difference? A sketch of where the cast would sit in a training step is below.
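To make sure I understand where that cast sits, here is a minimal sketch of a training step with the cast in it. This is not the actual t5x or Hugging Face code; the toy loss_fn and the optax.adafactor optimizer are just stand-ins:

from functools import partial

import jax
import jax.numpy as jnp
import optax

optimizer = optax.adafactor(learning_rate=1e-3)

def loss_fn(params, batch):
    # toy quadratic loss, just so there is something to differentiate
    return jnp.mean((batch @ params['w']) ** 2)

@partial(jax.jit, static_argnames='use_bfloat16')
def train_step(params, opt_state, batch, use_bfloat16=True):
    grads = jax.grad(loss_fn)(params, batch)  # gradients come out in float32
    if use_bfloat16:
        # the t5x-style cast: shrink the gradient pytree before the optimizer update
        grads = jax.tree_map(lambda g: g.astype(jnp.bfloat16), grads)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

params = {'w': jnp.ones((4, 2), jnp.float32)}
opt_state = optimizer.init(params)
params, opt_state = train_step(params, opt_state, jnp.ones((8, 4)))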

  • Another question, related to training speed: I am training t5-base with per_device_batch_size=16, which means batch_size=16*8=128, and max_seq_len=512 (same as the T5 paper). The T5 paper trains for 524,288 steps; my progress bar says it would take 8 days to reach that many steps. Is this good?
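
To put the estimate into numbers (taking the 8-day figure at face value): 524,288 steps / (8 × 86,400 s) ≈ 0.76 steps/s, i.e. roughly 1.3 s per step at an effective batch of 128 sequences × 512 tokens.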

A quick note on reduced precision training:

Params are typically still stored in float32, but some intermediate computations can be done in bfloat16. Most standard Flax modules have a dtype attribute that you can set to bfloat16 to compute the numerically safe parts in reduced precision.
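
For example, a minimal sketch with a plain Flax layer (the shapes are arbitrary): the dtype attribute only changes the computation/output dtype, while the params are still initialized in float32.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Dense layer with compute dtype bfloat16; parameters are still created in float32.
layer = nn.Dense(features=8, dtype=jnp.bfloat16)
x = jnp.ones((1, 4))
params = layer.init(jax.random.PRNGKey(0), x)

print(jax.tree_map(lambda p: p.dtype, params))  # kernel/bias: float32
print(layer.apply(params, x).dtype)             # bfloat16

If you want the stored weights themselves in half precision, I believe the Flax models in transformers also have to_bf16 / to_fp16 helpers for casting the params, but that is separate from the compute dtype.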
