How to accumulate loss when the number of examples per batch is not fixed

I am having trouble figuring out how to properly accumulate loss/gradients when each per-device batch contains a different number of "examples."

My examples are actually atomic systems, and a batch of examples will contain atoms from several different systems. Unlike tokens in NLP, which are padded to the same length, my model operates on the atoms in long form (all atoms concatenated) and keeps track of which atoms belong to which system with an indicator.

Given this, I created a model wrapper that returns the loss as a sum over atoms, so batches with more atoms have larger losses on average. This is intentional: since this is an atom-level task, I want a system with more atoms to count for more during training.
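
In other words, the wrapper does roughly this (the MSE loss and the batch/target names below are just placeholders for my actual setup):

import torch.nn.functional as F

def compute_loss(model, per_device_batch):
    # Model outputs one prediction per atom for the whole flattened batch
    per_atom_pred = model(per_device_batch)
    per_atom_target = per_device_batch["targets"]
    # reduction="sum": a system with more atoms contributes a larger loss
    return F.mse_loss(per_atom_pred, per_atom_target, reduction="sum")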

In my training loop I return that summed loss, e.g.:

loss = compute_loss(model, per_device_batch)  # <--- this is not normalized
atoms_in_batch = get_atoms_in_batch(per_device_batch)

The model and the data loader are wrapped with Accelerate. I imagine I can gather atoms_in_batch across devices to normalize the loss accumulated over devices and batches, but I am not sure how to do that, since we are expected to call accelerator.backward inside the loop.
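
Concretely, what I have in mind is something like the sketch below (accumulation_steps, optimizer, dataloader, etc. are from my usual setup; whether the scaling and the placement of accelerator.backward are actually correct is exactly what I'm unsure about):

import torch

for step, per_device_batch in enumerate(dataloader):
    # Un-normalized sum-over-atoms loss from my wrapper
    loss = compute_loss(model, per_device_batch)

    # Number of atoms this device saw, as a 1-element tensor so it can be gathered
    atoms_in_batch = torch.tensor(
        [get_atoms_in_batch(per_device_batch)], device=accelerator.device
    )

    # Gather the per-device atom counts so every process knows the global
    # number of atoms in this distributed batch
    total_atoms = accelerator.gather(atoms_in_batch).sum()

    # Normalize by the global atom count before backward -- this is the part
    # I'm unsure about (e.g. whether it also needs a factor of
    # accelerator.num_processes because gradients are averaged across devices)
    accelerator.backward(loss / total_atoms)

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

In particular, I don't know whether normalizing inside the loop like this interacts correctly with gradient accumulation and with the gradient averaging that DDP does across processes.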

Thanks for any assistance.
