How to use a TPU for model training with the example script run_mlm.py

Hi there,

I’ve been trying to use a Colab TPU to train BERT and other language models with the provided example script, run_mlm.py.

I’ve used the code for loading a TPU in Colab from the original T5 repo’s notebook.

That code provides the TPU address via TPU_ADDRESS = tpu.get_master(), which looks something like this: grpc://10.121.143.146:8470
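For reference, the Colab cell I’m using looks roughly like this (a minimal sketch; the resolver call is the standard TensorFlow one, only the printed address is from my session):

import tensorflow as tf

# Resolve the Colab TPU (assumes a TPU runtime is attached to the notebook)
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
TPU_ADDRESS = tpu.get_master()  # e.g. grpc://10.121.143.146:8470
print(TPU_ADDRESS)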

As stated in the docs (https://github.com/huggingface/transformers/tree/master/examples/tensorflow/language-modeling#multi-gpu-and-tpu-usage), a TPU can be used by passing the --tpu argument; however, doing this I get the following error:

!python /content/transformers/examples/tensorflow/language-modeling/run_mlm.py \
--model_name_or_path bert-base-cased \
--validation_split_percentage 20 \
--line_by_line \
--learning_rate 2e-5 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 32 \
--num_train_epochs 4 \
--output_dir /content/output \
--train_file /content/text.csv \
--tpu TPU_ADDRESS

run_mlm.py: error: ambiguous option: --tpu could match --tpu_num_cores, --tpu_metrics_debug, --tpu_name, --tpu_zone

Apparently, there is no option for passing the TPU address directly.

I’m also not sure how to get a tpu_name, so I can’t pass that instead.
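(From what I can tell from the TensorFlow docs, the grpc address itself can usually be passed wherever a TPU name is expected; a minimal sketch, assuming a Colab TPU runtime:)

import tensorflow as tf

# The grpc address printed by tpu.get_master() can be passed as the "name"
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu="grpc://10.121.143.146:8470"
)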

Any help is appreciated.


Hi,
I think you need to run this script (on Colab, with the TPU session active):

!python xla_spawn.py --num_cores 8 run_mlm.py \
--model_name_or_path bert-base-cased \
--tpu_num_cores 8 \
(plus all your other params)

Hi Gennaro,

Appreciate your response. Where exactly can I get this xla_spawn.py script from? Thanks!

EDIT: I actually managed to get the script to run just by passing the --tpu_num_cores 8 parameter. Thank you so much for your help!
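For anyone who finds this later, the full invocation that worked for me looked roughly like this (same file paths and hyperparameters as in my first post):

!python /content/transformers/examples/tensorflow/language-modeling/run_mlm.py \
--model_name_or_path bert-base-cased \
--validation_split_percentage 20 \
--line_by_line \
--learning_rate 2e-5 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 32 \
--num_train_epochs 4 \
--output_dir /content/output \
--train_file /content/text.csv \
--tpu_num_cores 8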

Hi @timodim,
here is the xla_spawn.py script; you can find it in the examples folder of the transformers repo. Glad it is working!

I am trying to train a TFDistilBertForTokenClassification with TPU + Colab. I worked around the problem of getting a TF dataset ready on GCS, but the model still fails to train. The most prominent error:

(0) INVALID_ARGUMENT: {{function_node __inference_train_function_26618}} Detected unsupported operations when trying to compile graph tf_distil_bert_for_token_classification_cond_true_22902 on XLA_TPU_JIT: PrintV2 (No registered 'PrintV2' OpKernel for XLA_TPU_JIT devices compatible with node {{node tf_distil_bert_for_token_classification/cond/PrintV2}}){{node tf_distil_bert_for_token_classification/cond/PrintV2}}
One approach is to outside compile the unsupported ops to run on CPUs by enabling soft placement tf.config.set_soft_device_placement(True). This has a potential performance penalty.

It seems the model uses some op that the TPU doesn’t support. Any idea how to work around this?
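For reference, this is where I think the workaround suggested in the error message would go; a minimal sketch, assuming a standard Colab TPU setup and that soft placement is enabled before the model is built (the checkpoint name is just an example):

import tensorflow as tf
from transformers import TFDistilBertForTokenClassification

# Per the error message: let XLA-unsupported ops (like PrintV2) fall back to CPU.
# Must be set before the graph is built; may carry a performance penalty.
tf.config.set_soft_device_placement(True)

# Standard Colab TPU initialization (assumes a TPU runtime is attached)
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# Build the model inside the TPU strategy scope
with strategy.scope():
    model = TFDistilBertForTokenClassification.from_pretrained("distilbert-base-cased")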