TPU trainer with multi-core

Hello, I’d like to train a SQuAD model using the 8-core TPU offered by Google Colab.
I followed the tutorial on fine-tuning with the Trainer API by @sgugger.
The tutorial states that the Trainer supports TPUs out of the box, so the only things I added to my code were the XLA library installation and imports. Then I wrapped the .train() call inside _mp_fn() and passed it to xmp.spawn(), specifying 8 as the number of cores.

But I get the following error:

Exception in device=TPU:0: Cannot replicate if number of devices (1) is different from 8

After some research, I found that this error is raised when the XLA device is accessed outside the spawned processes. I have no such calls in my code, so maybe one happens inside one of the Hugging Face functions. How do I disable it?
The code works just fine with just 1 core, but I’d like to exploit all 8.
This is my notebook:

You need to define the TrainingArguments inside your multiprocessing function, and you should also define your model in that function:

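from transformers import AutoModelForQuestionAnswering, Trainer, TrainingArguments, default_data_collator

def _mp_fn(rank, flags):
    # Model, TrainingArguments and Trainer are created inside each spawned process;
    # model_checkpoint and tokenized_datasets come from earlier cells of the tutorial notebook.
    model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
    model_name = model_checkpoint.split("/")[-1]
    args = TrainingArguments(
        f"{model_name}-finetuned-squad",  # output_dir, as in the tutorial
        evaluation_strategy="epoch",  # the remaining arguments were not shown in the post
    )
    trainer = Trainer(model, args, data_collator=default_data_collator,
                      train_dataset=tokenized_datasets["train"], eval_dataset=tokenized_datasets["validation"])
    trainer.train()
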
This should work properly.
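For completeness, here is a minimal sketch of how such an _mp_fn can be launched from a Colab cell. It assumes torch_xla is installed as in the question; launch_on_tpu is a hypothetical helper name, while xmp.spawn is the PyTorch/XLA launcher the question already uses. start_method="fork" is what the PyTorch/XLA Colab examples use, since functions defined in a notebook cell generally cannot be re-imported by freshly spawned processes.

```python
def launch_on_tpu(mp_fn, flags=None, nprocs=8):
    """Start `nprocs` processes, one per TPU core; each runs mp_fn(rank, flags).

    Do not call xm.xla_device() (or create the model/TrainingArguments) before
    this point: the XLA device must be acquired inside mp_fn, in each process.
    """
    # Imported here so the helper can be defined even where torch_xla is absent
    import torch_xla.distributed.xla_multiprocessing as xmp

    # "fork" lets the child processes reuse functions defined in notebook cells
    xmp.spawn(mp_fn, args=(flags,), nprocs=nprocs, start_method="fork")


# Usage in a Colab cell, with _mp_fn defined as in the answer above:
# launch_on_tpu(_mp_fn, flags={}, nprocs=8)
```

Nothing in the cell that calls launch_on_tpu should touch the XLA device itself; keeping all device-dependent setup inside _mp_fn is what avoids the "Cannot replicate" error.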

Hi. I passed tpu_num_cores=8 to the TrainingArguments class. I also run my code on Google Colab, and I get the following error message:

2022-04-19 15:51:14.312185: E tensorflow/compiler/xla/xla_client/] *** Begin stack trace ***
2022-04-19 15:51:14.312193: E tensorflow/compiler/xla/xla_client/] 	tensorflow::CurrentStackTrace()
2022-04-19 15:51:14.312202: E tensorflow/compiler/xla/xla_client/] 	xla::util::ReportComputationError(tensorflow::Status const&, absl::lts_20211102::Span<xla::XlaComputation const* const>, absl::lts_20211102::Span<xla::Shape const* const>)
2022-04-19 15:51:14.312211: E tensorflow/compiler/xla/xla_client/] 	xla::util::ShapeHash(xla::Shape const&)
2022-04-19 15:51:14.312219: E tensorflow/compiler/xla/xla_client/] 	xla::XrtComputationClient::ExecuteComputation(xla::ComputationClient::Computation const&, absl::lts_20211102::Span<std::shared_ptr<xla::ComputationClient::Data> const>, std::string const&, xla::ComputationClient::ExecuteComputationOptions const&)
2022-04-19 15:51:14.312227: E tensorflow/compiler/xla/xla_client/] 	
2022-04-19 15:51:14.312234: E tensorflow/compiler/xla/xla_client/] 	xla::util::MultiWait::Complete(std::function<void ()> const&)
2022-04-19 15:51:14.312240: E tensorflow/compiler/xla/xla_client/] 	
2022-04-19 15:51:14.312246: E tensorflow/compiler/xla/xla_client/] 	
2022-04-19 15:51:14.312253: E tensorflow/compiler/xla/xla_client/] 	
2022-04-19 15:51:14.312259: E tensorflow/compiler/xla/xla_client/] 	clone
2022-04-19 15:51:14.312265: E tensorflow/compiler/xla/xla_client/] *** End stack trace ***
2022-04-19 15:51:14.312271: E tensorflow/compiler/xla/xla_client/] 
2022-04-19 15:51:14.312277: E tensorflow/compiler/xla/xla_client/] Status: INTERNAL: From /job:tpu_worker/replica:0/task:0:
2022-04-19 15:51:14.312289: E tensorflow/compiler/xla/xla_client/] 2 root error(s) found.
2022-04-19 15:51:14.312300: E tensorflow/compiler/xla/xla_client/]   (0) INTERNAL: stream did not block host until done; was already in an error state
2022-04-19 15:51:14.312310: E tensorflow/compiler/xla/xla_client/] 	 [[{{node XRTExecute}}]]
2022-04-19 15:51:14.312320: E tensorflow/compiler/xla/xla_client/] 	 [[XRTExecute_G15]]
2022-04-19 15:51:14.312329: E tensorflow/compiler/xla/xla_client/]   (1) INTERNAL: stream did not block host until done; was already in an error state
2022-04-19 15:51:14.312338: E tensorflow/compiler/xla/xla_client/] 	 [[{{node XRTExecute}}]]
2022-04-19 15:51:14.312348: E tensorflow/compiler/xla/xla_client/] 0 successful operations.
2022-04-19 15:51:14.312364: E tensorflow/compiler/xla/xla_client/] 0 derived errors ignored.
  0% 1/21250 [00:00<2:20:22,  2.52it/s]
Traceback (most recent call last):
  File "", line 71, in <module>
  File "/content/common-crawal-preprocess/model_trainings/procedure_torch/", line 69, in main
  File "/usr/local/lib/python3.7/dist-packages/transformers/", line 1306, in train
    for step, inputs in enumerate(epoch_iterator):
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/", line 34, in __next__
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/", line 46, in next
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/core/", line 787, in mark_step
    wait=xu.getenv_as('XLA_SYNC_WAIT', bool, False))
RuntimeError: INTERNAL: From /job:tpu_worker/replica:0/task:0:
2 root error(s) found.
  (0) INTERNAL: stream did not block host until done; was already in an error state
	 [[{{node XRTExecute}}]]
  (1) INTERNAL: stream did not block host until done; was already in an error state
	 [[{{node XRTExecute}}]]
0 successful operations.
0 derived errors ignored.

What does the message mean?
Note that the error does not appear when I do not specify tpu_num_cores.

Thank you! It works properly now.

@kensuke-mi I couldn’t reproduce your error. Did you try restarting the runtime?

@Kioto97 You’re right. I needed to restart the Colab instance; after restarting, the error went away. So, is training much faster for you with tpu_num_cores=8? It becomes slower in my configuration. I tested it with per_device_train_batch_size=16.
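One thing worth checking when comparing speeds: on TPU, per_device_train_batch_size applies to each core, so with 8 cores each optimizer step processes 8 × 16 = 128 examples and an epoch has roughly 8× fewer steps than on one core. Comparing iterations per second between the two settings can therefore be misleading; time per epoch is the fairer comparison. A small sketch of the arithmetic, where steps_per_epoch is a hypothetical helper and the dataset size is a placeholder, not a number from this thread:

```python
import math


def steps_per_epoch(num_examples: int, per_device_batch_size: int, num_cores: int = 1) -> int:
    """Approximate optimizer steps per epoch when every core processes its own batch."""
    global_batch = per_device_batch_size * num_cores
    # Assumes the last, partial batch is kept rather than dropped
    return math.ceil(num_examples / global_batch)


# Hypothetical dataset size, for illustration only
n = 100_000
print(steps_per_epoch(n, 16, 1))  # 6250
print(steps_per_epoch(n, 16, 8))  # 782
```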

To be honest, I haven’t managed to run it yet. On Colab, even with a lower batch size, I get an error:

process 4 terminated with signal SIGKILL

I suspect this is due to insufficient RAM. Did you change anything else?
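A rough way to sanity-check the RAM suspicion: each of the 8 processes started by xmp.spawn loads its own copy of the model in host memory before it is moved to its TPU core, so host RAM for the weights alone grows with the number of processes, independently of the batch size. A minimal estimate, assuming fp32 weights; host_ram_for_weights_gib is a hypothetical helper and the parameter count is an illustrative placeholder, not a measurement of any specific checkpoint:

```python
def host_ram_for_weights_gib(num_params: int, num_processes: int, bytes_per_param: int = 4) -> float:
    """Rough host RAM (GiB) for one copy of the weights in every process.

    Ignores optimizer state, datasets and Python/runtime overhead,
    so real usage will be noticeably higher.
    """
    return num_params * bytes_per_param * num_processes / 1024**3


# Hypothetical model with 110 million parameters, loaded by 8 TPU processes
print(round(host_ram_for_weights_gib(110_000_000, 8), 2))  # 3.28
```

Lowering per_device_train_batch_size does not shrink this part, which would be consistent with the SIGKILL persisting at smaller batch sizes.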