The distributed evaluation section of the docs says to use accelerator.gather
to collect predictions from the various devices.
My question: do you also need accelerator.gather
when collecting the losses from the various cores? I defined the loss calculation as follows:
# Evaluate at the end of the epoch (distributed evaluation as we have 8 TPU cores)
model.eval()
validation_losses = []
for batch in val_dataloader:
    with torch.no_grad():
        outputs = model(**batch)
    loss = outputs.loss
    # We gather the loss from the 8 TPU cores to have them all.
    validation_losses.append(accelerator.gather(loss))

# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", sum(validation_losses) / len(validation_losses))
However, this fails:
  File "/usr/local/lib/python3.7/dist-packages/accelerate/accelerator.py", line 290, in gather
    return gather(tensor)
  File "/usr/local/lib/python3.7/dist-packages/accelerate/utils.py", line 171, in gather
    return _tpu_gather(tensor, name="accelerate.utils.gather")
  File "/usr/local/lib/python3.7/dist-packages/accelerate/utils.py", line 144, in _tpu_gather
    return xm.mesh_reduce(name, tensor, torch.cat)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/core/xla_model.py", line 916, in mesh_reduce
    return reduce_fn(xldata) if xldata else cpu_data
RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated
Presumably because the loss is a zero-dimensional tensor rather than at least 1-D, so torch.cat cannot concatenate the gathered values.
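For what it's worth, a likely workaround (a sketch, not verified on TPU) is to give the loss a batch dimension before gathering, e.g. accelerator.gather(loss.unsqueeze(0)), since the gather ultimately calls torch.cat, which needs at least 1-D tensors. This can be reproduced with plain torch, standing in for the TPU mesh_reduce:

```python
import torch

# outputs.loss is a zero-dimensional tensor, like this one.
loss = torch.tensor(2.5)

# torch.cat (what mesh_reduce applies under the hood) rejects 0-dim tensors.
try:
    torch.cat([loss, loss])
except RuntimeError as e:
    print(e)  # zero-dimensional tensor (at position 0) cannot be concatenated

# Adding a batch dimension makes concatenation, and hence the gather, work.
loss_1d = loss.unsqueeze(0)               # shape (1,)
gathered = torch.cat([loss_1d, loss_1d])  # stands in for accelerator.gather(loss_1d) across 2 cores
print(gathered.shape)                     # torch.Size([2])
```

So in the loop above, the append line would become validation_losses.append(accelerator.gather(loss.unsqueeze(0))).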