How to use a TPU for model training with the example script

Hi there,

I’ve been trying to use the Colab TPU to train BERT and other language models with the provided example script.

I’ve used the code for setting up a TPU in Colab from the original T5 repo here (Google Colab).

That code provides the TPU address via TPU_ADDRESS = tpu.get_master(), which looks something like this: grpc://
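For context, my TPU setup cell is roughly the following (a minimal sketch of the usual Colab pattern; the T5 notebook may differ in details):

import tensorflow as tf

# Resolve and initialize the Colab TPU runtime
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # picks up the Colab TPU
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)

TPU_ADDRESS = tpu.get_master()  # e.g. grpc://<ip>:8470
print(TPU_ADDRESS)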

As stated in the docs (transformers/examples/tensorflow/language-modeling at master · huggingface/transformers · GitHub), a TPU can be used by passing the --tpu argument. However, when I do this I get the following error:

!python /content/transformers/examples/tensorflow/language-modeling/ \
--model_name_or_path bert-base-cased \
--validation_split_percentage 20 \
--line_by_line \
--learning_rate 2e-5 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 32 \
--num_train_epochs 4 \
--output_dir /content/output \
--train_file /content/text.csv \
--tpu TPU_ADDRESS

error: ambiguous option: --tpu could match --tpu_num_cores, --tpu_metrics_debug, --tpu_name, --tpu_zone

Apparently, there is no option for passing the TPU address directly.

I’m also not sure how to find the tpu_name, so I can’t pass that either.
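For what it’s worth, my understanding (not verified against the script) is that tpu_name refers to a named Cloud TPU node, whereas on Colab the TF cluster resolver can take the grpc address directly, e.g.:

import tensorflow as tf

# Named Cloud TPU node (hypothetical name/zone/project)
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu='my-tpu-node', zone='us-central1-b', project='my-gcp-project')

# Colab: pass the grpc address from tpu.get_master() directly
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=TPU_ADDRESS)

so there may simply be no name to pass on Colab.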

Any help is appreciated.

I think you need to run this script (on Colab, with TPU session active):

!python --num_cores 8 \
--model_name_or_path bert-base-cased \
--tpu_num_cores 8 \
(all other params)

Hi Gennaro,

Appreciate your response. Where exactly can I get this script from? Thanks!

EDIT: I actually managed to get the script to run just by passing the --tpu_num_cores 8 parameter. Thank you so much for your help!
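For anyone finding this later, the working invocation was essentially the same command as above with the --tpu line swapped out:

!python /content/transformers/examples/tensorflow/language-modeling/ \
--model_name_or_path bert-base-cased \
--validation_split_percentage 20 \
--line_by_line \
--learning_rate 2e-5 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 32 \
--num_train_epochs 4 \
--output_dir /content/output \
--train_file /content/text.csv \
--tpu_num_cores 8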

Hi @timodim,
here is the xla_spawn script. Glad it is working!

I am trying to train a TFDistilBertForTokenClassification with a TPU on Colab. I worked around the problem of getting a TF dataset ready on GCS, but the model still fails to train. The most prominent error:

(0) INVALID_ARGUMENT: {{function_node __inference_train_function_26618}} Detected unsupported operations when trying to compile graph tf_distil_bert_for_token_classification_cond_true_22902 on XLA_TPU_JIT: PrintV2 (No registered 'PrintV2' OpKernel for XLA_TPU_JIT devices compatible with node {{node tf_distil_bert_for_token_classification/cond/PrintV2}}){{node tf_distil_bert_for_token_classification/cond/PrintV2}}
One approach is to outside compile the unsupported ops to run on CPUs by enabling soft placement tf.config.set_soft_device_placement(True). This has a potential performance penalty.

It seems the model uses an op (PrintV2) that the TPU doesn’t support. Any idea how to hack around this?
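If I read the hint in the error correctly, the workaround would be to enable soft placement before setting up the TPU strategy and building the model, something like this (untested sketch; strategy setup assumed to follow the usual Colab pattern):

import tensorflow as tf

# Let ops XLA can't compile for the TPU (here PrintV2) fall back to the CPU
tf.config.set_soft_device_placement(True)

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    # build and compile TFDistilBertForTokenClassification here
    ...

but I don’t know whether the performance penalty is acceptable, or whether there is a cleaner way to avoid the PrintV2 op altogether.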