Ideas for beginner-friendlier TPU-VM clm training

Hello All,

I’ve recently started trying out TPU-VMs and wanted to train distilgpt2 from scratch on a non-English language. I had a rather rough start, but I managed to overcome the issues and get training going.
Following the advice from Suraj Patil, I decided to write up a list of things that could make this experience a bit smoother for others.
If you too think these items might help, perhaps some could be added to the relevant tutorials, docs, or even the example code itself.

GIT credentials for Huggingface

  • Perhaps explain / demonstrate how to store and cache them? (Extra useful if the model is “Private”)
  • Because I did not cache them properly at first, and because I had junk in the keyboard buffer, the script failed at the point where it tried to “push to hub”. It does not seem to retry, nor does ‘run_clm_flax.py’ save the checkpoint to disk when such an error happens, so all those hours of training were lost.
  • I’d suggest that if a training script fails authentication for any reason, it should not exit before it has written the checkpoint to disk.
  • Is it possible to verify that git credentials are cached before training starts?
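To make the suggestion concrete, here is a minimal sketch of the idea. The names `save_then_push`, `save_checkpoint`, and `push_to_hub` are placeholders of my own, not the actual run_clm_flax.py API; the point is only the ordering: persist the checkpoint first, and treat a failed push as non-fatal.

```python
# Hypothetical sketch -- save_checkpoint / push_to_hub are placeholder callbacks,
# not the real run_clm_flax.py API. The checkpoint is written to disk BEFORE the
# push is attempted, so a failed authentication can never lose the training run.
def save_then_push(save_checkpoint, push_to_hub, save_dir):
    save_checkpoint(save_dir)           # checkpoint is safe on disk first
    try:
        push_to_hub(save_dir)           # may fail, e.g. on uncached git credentials
        return "pushed"
    except Exception as err:
        # Report the failure but do not crash; the checkpoint survives either way
        return f"push failed, checkpoint kept at {save_dir}: {err}"
```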

Eval phase

  • Same as above: in case of an error, do not exit before writing the checkpoint, to avoid losing training time

Log

  • Is there a built-in option to write all the “on screen” output to a log file? If so, I couldn’t find it.
  • Why are all of those “tcmalloc” lines printed (even at the “Info” log level)? Is there a way to suppress them? They clutter the screen to the point where it’s hard to find the actual log messages.
  • I “worked around” it by piping everything through grep -v “tcmalloc”, but is there a better solution?
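For reference, the workaround written out in full. Adding tee keeps a complete, unfiltered copy in a log file while grep -v cleans the screen; the script name and arguments are just an example, substitute your own command.

```shell
# Keep the complete output (including the tcmalloc noise) in train.log,
# but show only the filtered stream on screen.
# run_clm_flax.py and its arguments are placeholders for your own command.
python run_clm_flax.py --output_dir ./out 2>&1 | tee train.log | grep -v "tcmalloc"
```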

dtype

  • float16 - Doesn’t really work well with TPUs, right? Perhaps disallow it when training with JAX?
  • bfloat16 - Can it be converted to PyTorch? Only after I had my first checkpoint did I discover that I cannot use Flax models with the online inference box, and that AutoModelForCausalLM fails to load bfloat16 checkpoints.
  • float32 - I decided to restart with float32, since it loads fine with AutoModelForCausalLM and saves fine as PyTorch. Is this the only/best option if I also want my Flax model as PyTorch?
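One possible middle ground, if I understand the dtype issue correctly, would be to train in bfloat16 but cast the parameters to float32 before saving the checkpoint. A minimal sketch of just the cast (assumes jax is installed; `to_float32` is my own helper name, not part of any library):

```python
import jax
import jax.numpy as jnp

def to_float32(params):
    """Cast every floating-point leaf of a Flax param tree to float32,
    leaving integer leaves (e.g. step counters) untouched."""
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32) if jnp.issubdtype(x.dtype, jnp.floating) else x,
        params,
    )
```

After a cast like this, the checkpoint should be loadable on the PyTorch side via AutoModelForCausalLM.from_pretrained(path, from_flax=True), which is how transformers converts Flax weights to PyTorch.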

Upgrade CLU

  • Somehow I missed the upgrade step (pip install --upgrade clu), which resulted in a lot of weird errors and malloc failures.
  • If it’s not just me, which it totally might be :slight_smile: perhaps this step should be more prominent?

Anyway,
This is my feedback,
I hope someone finds it useful.

Great work everyone!

  • Follow this tutorial to set-up and connect to your TPU-VM

  • Add your local bin path to the PATH environment variable. If you do not know your local user name, type:
    whoami
    #In my case, the user is ‘dadler’, so replace ‘dadler’ in the following block with your own user:
    nano ~/.bashrc
    #Add the following line at the bottom, replace dadler with your own user name
    export PATH="/home/dadler/.local/bin:$PATH"
    #Save (Ctrl+O), then exit (Ctrl+X)
    #Reload bashrc
    source ~/.bashrc
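The same setup can also be done without opening an editor, and using $HOME avoids having to substitute your user name at all:

```shell
# Append the local bin directory to PATH; $HOME expands to /home/<your-user>,
# so no manual username substitution is needed.
echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.bashrc
source ~/.bashrc
```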

  • Install and upgrade libraries

pip install datasets
git clone https://github.com/huggingface/transformers.git
pip install --user -e transformers
pip install --upgrade tokenizers
pip install --upgrade clu
git clone https://github.com/google/flax.git
pip install --user -e flax
pip install git+https://github.com/deepmind/optax.git
  • Set up git-lfs
sudo apt install git-lfs
git lfs install
  • Login to your huggingface account
    huggingface-cli login

  • Save your git credentials on the local VM (not secure; do this only if you are the only person who has access to the TPU-VM instance)

git config --global credential.helper 'store --file ~/.git-credentials'
git credential fill

#Type/Paste the following two lines:

protocol=https
host=huggingface.co

#Now hit Enter until you are prompted to enter your huggingface user and password
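To answer my own earlier question about checking the cache before a long training run starts: git can be asked non-interactively whether a credential for huggingface.co is already stored. A sketch (adjust the store file path if you used a different one above):

```shell
# Ask the credential store (non-interactively) whether huggingface.co
# credentials are already cached; warn before launching a long training run.
if printf 'protocol=https\nhost=huggingface.co\n\n' \
     | git credential-store --file ~/.git-credentials get \
     | grep -q '^password='; then
    echo "huggingface.co credentials are cached"
else
    echo "no cached credentials - run the login steps first" >&2
fi
```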

  • ~~Not sure if this is needed, but in the terminal session you are about to run your training script in, you might want to type:~~
    ~~export XRT_TPU_CONFIG="localservice;0;localhost:51011"~~

  • Continue by following the instructions in this tutorial

Hope it helps :slight_smile:

> • Not sure if this is needed, but in the terminal session you are about to run your training script in, you might want to type:
>   export XRT_TPU_CONFIG="localservice;0;localhost:51011"

AFAIK, this is needed if you are using pytorch/xla, and not needed for flax/jax.
You might also need to run export USE_TORCH=False if you have torch installed but want to use flax.


Ahh, thank you!