Get the loss from all TPU cores

The distributed evaluation section of the docs says to use accelerator.gather to collect data from the various devices.

My question: do you also need to use accelerator.gather when collecting the losses from the various cores? I defined the loss calculation as follows:

# Evaluate at the end of the epoch (distributed evaluation as we have 8 TPU cores)
model.eval()
validation_losses = []
for batch in val_dataloader:
    with torch.no_grad():
        outputs = model(**batch)
    loss = outputs.loss

    # We gather the loss from the 8 TPU cores to have them all.
    validation_losses.append(accelerator.gather(loss))

# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", sum(validation_losses) / len(validation_losses))

However, this fails:

  File "/usr/local/lib/python3.7/dist-packages/accelerate/accelerator.py", line 290, in gather
    return gather(tensor)
  File "/usr/local/lib/python3.7/dist-packages/accelerate/utils.py", line 171, in gather
    return _tpu_gather(tensor, name="accelerate.utils.gather")
  File "/usr/local/lib/python3.7/dist-packages/accelerate/utils.py", line 144, in _tpu_gather
    return xm.mesh_reduce(name, tensor, torch.cat)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/core/xla_model.py", line 916, in mesh_reduce
    return reduce_fn(xldata) if xldata else cpu_data
RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated

This fails because the loss is a zero-dimensional (scalar) tensor, while gathering needs a tensor with at least one dimension.
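The traceback shows the mechanism: `_tpu_gather` calls `xm.mesh_reduce(name, tensor, torch.cat)`, so the tensors from all cores are combined with `torch.cat`. The CPU-only sketch below reproduces that step without a TPU. `concat_like_tpu_gather` is a hypothetical helper that only mimics the reduce function, and the eight loss values are made up.

```python
import torch

# Hypothetical stand-in for the per-core losses: one 0-d tensor per TPU core
# (8 cores assumed, matching the question). Values are made up.
NUM_CORES = 8
per_core_losses = [torch.tensor(0.5 + i) for i in range(NUM_CORES)]


def concat_like_tpu_gather(tensors):
    """Hypothetical helper mimicking the reduce_fn in the traceback: torch.cat along dim 0."""
    return torch.cat(tensors)


# A scalar tensor has no dimension that torch.cat could concatenate along.
print(per_core_losses[0].dim())  # prints 0

try:
    concat_like_tpu_gather(per_core_losses)
except RuntimeError as err:
    # Same kind of error as in the traceback above.
    print("gather failed:", err)
```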

You need a dimension along which to gather (the first one by default), so just add one:

validation_losses.append(accelerator.gather(loss[None]))

and it should work.
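With that change, each gather returns a tensor of shape (8,), one loss per core (assuming 8 processes). The averaging line from the question then still returns 8 values rather than a single number, because summing a Python list of (8,)-shaped tensors keeps the per-core dimension. A minimal CPU sketch, with made-up values standing in for three gathered batches:

```python
import torch

# Assumptions for this sketch: 8 TPU cores and 3 validation batches.
# Each entry stands for accelerator.gather(loss[None]), i.e. a tensor of
# shape (8,) holding one loss value per core. The values are made up.
NUM_CORES = 8
validation_losses = [
    torch.full((NUM_CORES,), 1.0),
    torch.full((NUM_CORES,), 2.0),
    torch.full((NUM_CORES,), 3.0),
]

# The question's averaging keeps the per-core dimension: shape (8,), not a scalar.
per_core_average = sum(validation_losses) / len(validation_losses)
print(per_core_average.shape)  # torch.Size([8])

# Concatenate every gathered tensor, then take a single mean over all values.
epoch_loss = torch.cat(validation_losses).mean()
print(epoch_loss.item())  # 2.0
```

In the loop from the question, one option is to print `torch.cat(validation_losses).mean().item()` instead of `sum(validation_losses) / len(validation_losses)`. This treats every batch equally, which matches the original intent when all batches have the same size.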
