I’ve recently started trying out TPU VMs, aiming to train distilgpt2 from scratch on a non-English language. I had a rather rough start, but I did manage to overcome the issues and get training going.
Following Suraj Patil’s advice, I decided to write up a list of things that could make this experience a bit smoother for others.
If you agree these items might help, perhaps some could be added to the relevant tutorials, docs, or even the example code itself.
Git credentials for Hugging Face
- Perhaps explain / demonstrate how to store and cache them? (Especially useful if the model is private.)
- Because I did not cache credentials properly at first, and because I had junk in the keyboard buffer, the script failed when it eventually tried to “push to hub”. It does not seem to retry, nor does ‘run_clm_flax.py’ save the checkpoint to disk when such an error happens, so all the hours of training were lost.
- I’d suggest that if a training script hits an authentication failure for any reason, it should not exit before writing the checkpoint to disk.
- Is it possible to somehow verify that Git credentials are cached before training starts?
- Same as above: in case of an error, write the checkpoint before exiting, to avoid losing training time.
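For what it’s worth, here is the setup sketch that made unattended pushes work for me (assuming a throwaway TPU VM, where storing credentials in plaintext on disk is acceptable):

```shell
# Cache the Hugging Face token once, interactively (stored under ~/.huggingface):
huggingface-cli login

# Tell git to remember hub credentials on disk (~/.git-credentials), so a
# "push to hub" hours into training does not prompt for input and fail:
git config --global credential.helper store
```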
Logging
- Is there a built-in option to write all the “on screen” output to a log file? If so, I couldn’t find it.
- Why are all of those “tcmalloc” lines printed (even at the “Info” log level)? Is there a way to suppress them? They clutter the screen to the point where it’s hard to find the actual log messages.
- I “worked around” it by piping everything through grep -v “tcmalloc”, but is there a better solution?
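As a sketch, the workaround combined with file logging (the invocation is hypothetical) looks like this; tee keeps the full, unfiltered output in train.log while the terminal stays readable:

```shell
# Everything (tcmalloc noise included) goes to train.log;
# only the non-tcmalloc lines reach the screen.
python run_clm_flax.py --output_dir ./my-model 2>&1 | tee train.log | grep -v "tcmalloc"
```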
Data types
- float16 - Doesn’t really work well with TPUs, right? Perhaps disallow it when training with JAX?
- bfloat16 - Can it be converted to PyTorch? Only after my first checkpoint did I discover that I cannot use Flax models with the online inference widget, and that AutoModelForCausalLM fails to load bfloat16 checkpoints.
- float32 - I decided to restart with float32, since it loads fine with AutoModelForCausalLM and saves correctly as PyTorch. Is this the only/best option if I want my Flax model available as PyTorch too?
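For anyone hitting the same wall, this is the conversion sketch that worked for me with a float32 checkpoint (the directory name is a placeholder for your own checkpoint path):

```python
# Load a float32 Flax checkpoint into PyTorch via from_flax=True, then save
# it as a regular PyTorch checkpoint so AutoModelForCausalLM (and the hosted
# inference widget) can use it. "my-model-dir" is a hypothetical path.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "my-model-dir",   # directory containing flax_model.msgpack + config.json
    from_flax=True,   # convert the Flax weights to PyTorch tensors
)
model.save_pretrained("my-model-dir")  # writes the PyTorch weights alongside
```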
Dependencies
- Somehow I missed the upgrade step (pip install --upgrade clu), which resulted in a lot of weird errors and malloc failures.
- If it’s not just me (which it totally might be), perhaps this step should be made more prominent?
This is my feedback,
I hope someone finds it useful.
Great work everyone!
- Doron Adler