- How do I get the currently available memory on a TPU, the way nvidia-smi does for a GPU? I tried using bfloat16 on TPU, but it doesn't seem to help. Is there a way to get TPU memory info from another process? (TensorBoard didn't work, though I'm not sure whether I did something wrong, and pprof looks like it can only profile the usage of the process itself.) I'm getting paranoid and starting to think there might be TPU memory held by a dead process.
I decided to fall back to a GPU to debug and get an idea of how much memory is used. I tried Adafactor and float16, and the memory savings seem small. (I set XLA_PYTHON_CLIENT_ALLOCATOR=platform and XLA_PYTHON_CLIENT_PREALLOCATE=false so that nvidia-smi shows the actual usage instead of the preallocated pool.) A sketch of how I check memory from inside the process is below.
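For reference, this is the minimal sketch I use to check per-device memory from inside the running process. It assumes a recent enough JAX where `Device.memory_stats()` is available (on some backends it returns `None`), so treat the call and its keys as an assumption rather than a guaranteed API:

```python
import os

# These allocator settings only matter on GPU and must be set before importing
# jax; they disable preallocation so external tools report real usage.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax

for d in jax.local_devices():
    stats = d.memory_stats()  # may be None on backends without support
    if stats:
        used = stats.get("bytes_in_use", 0) / 2**30
        limit = stats.get("bytes_limit", 0) / 2**30
        print(f"{d}: {used:.2f} GiB in use / {limit:.2f} GiB limit")
    else:
        print(f"{d}: memory_stats() not available on this backend")
```

This only sees the current process, though, so it wouldn't catch memory held by a dead process.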
- How do float16/bfloat16 work? I printed the params and they are float32; is that the expected behavior? Here is the code (a casting sketch follows the output):
```python
import jax
import jax.numpy as jnp
from transformers import FlaxT5ForConditionalGeneration, AutoConfig

config = AutoConfig.from_pretrained('t5-base')
model = FlaxT5ForConditionalGeneration(config, seed=0, dtype=jnp.float16)
print(jax.tree_map(lambda x: x.dtype, model.params))
```
Output (trimmed):

```
{'decoder': {'block': {'0': {'layer': {'0': {'SelfAttention': {'k': {'kernel': dtype('float32')},
'o': {'kernel': dtype('float32')},
'q': {'kernel': dtype('float32')},
'relative_attention_bias': {'embedding': dtype('float32')},
```
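From what I can tell, the `dtype` argument to the Flax models is the computation dtype, while the parameters themselves are still created in float32, which matches the output above. If the goal is to actually store the params in half precision, one option, sketched below under the assumption that the `to_bf16`/`to_fp16` helpers on the Hugging Face Flax models work the way I think they do, is to cast the parameter pytree explicitly:

```python
import jax
import jax.numpy as jnp
from transformers import FlaxT5ForConditionalGeneration, AutoConfig

config = AutoConfig.from_pretrained('t5-base')
# dtype is the dtype used for computation; params are still initialized in float32.
model = FlaxT5ForConditionalGeneration(config, seed=0, dtype=jnp.bfloat16)

# Cast the stored parameters themselves (assumption: to_bf16 exists on the model;
# a plain tree_map cast does the same thing).
params_bf16 = model.to_bf16(model.params)
# Equivalent manual cast:
# params_bf16 = jax.tree_map(lambda x: x.astype(jnp.bfloat16), model.params)

print(jax.tree_map(lambda x: x.dtype, params_bf16))
```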
- Another float16/bfloat16 question: in the Google t5x (Flax T5) code, there is this line:
```python
if use_bfloat16:
    grad = jax.tree_map(lambda x: x.astype(jnp.bfloat16), grad)
```
But the Hugging Face example doesn't have this line. Which one is "correct"? Is there a difference? A sketch of where the cast sits is below.
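To make the comparison concrete, here is a minimal runnable sketch of where that cast sits in a training step. The toy loss, params, and Adafactor setup are placeholders of my own, not the actual t5x or Hugging Face code; the point is only that the cast changes what the optimizer consumes while the params stay float32:

```python
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

# Toy setup so the sketch runs: a single weight vector and a squared-error loss.
params = {"w": jnp.ones((4,), jnp.float32)}
tx = optax.adafactor(learning_rate=1e-3)
state = train_state.TrainState.create(apply_fn=None, params=params, tx=tx)

def loss_fn(params, batch):
    return jnp.sum((params["w"] * batch) ** 2)

def train_step(state, batch, use_bfloat16=True):
    grad = jax.grad(loss_fn)(state.params, batch)
    if use_bfloat16:
        # t5x-style: cast the gradient pytree to bfloat16 before the optimizer
        # update; this halves gradient memory/bandwidth but loses some precision.
        grad = jax.tree_map(lambda x: x.astype(jnp.bfloat16), grad)
    # The Hugging Face example skips the cast and applies float32 grads directly.
    return state.apply_gradients(grads=grad)

state = train_step(state, jnp.arange(4.0))
print(jax.tree_map(lambda x: x.dtype, state.params))  # params remain float32
```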
- Another question, about training speed: I am training t5-base with per_device_batch_size=16, which means batch_size = 16 * 8 = 128, and max_seq_len=512 (same as the T5 paper). The T5 paper trains for 524,288 steps, and my progress bar says it would take 8 days to reach that many steps. Is this good?
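For reference, the throughput implied by that estimate is straightforward arithmetic (assuming the 8 days is wall-clock time for all 524,288 steps):

```python
steps = 524_288
seconds = 8 * 24 * 3600                       # 691,200 s in 8 days
steps_per_sec = steps / seconds               # ~0.76 steps/s, i.e. ~1.32 s per step
tokens_per_sec = steps_per_sec * 128 * 512    # ~49.7k input tokens/s across all devices
print(f"{steps_per_sec:.2f} steps/s, {tokens_per_sec:,.0f} tokens/s")
```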