Tracking the number of tokens seen during training in wandb with the Trainer API

I have a training loop that uses the Trainer API and reports the default metrics to Weights & Biases (wandb).

To study scaling behavior, it is common to report the number of tokens seen during training rather than the number of steps or examples seen. However, I don't see a straightforward way to do this.
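If every training sequence is packed to a fixed length with no padding, the token count is a linear rescaling of the step count the Trainer already tracks (state.global_step). Here is a minimal sketch of that arithmetic under that assumption. The function name is hypothetical, and its argument names mirror TrainingArguments fields only for readability.

```python
def tokens_seen(global_step: int,
                per_device_train_batch_size: int,
                gradient_accumulation_steps: int,
                world_size: int,
                seq_len: int) -> int:
    """Hypothetical helper: tokens consumed after `global_step` optimizer steps,
    assuming every sequence is exactly `seq_len` tokens (packed, no padding)."""
    # Sequences per optimizer step, summed over all devices and micro-batches.
    sequences_per_step = (per_device_train_batch_size
                          * gradient_accumulation_steps
                          * world_size)
    return global_step * sequences_per_step * seq_len

# 1000 steps, 8 sequences per device, 4 accumulation steps, 2 GPUs, 1024 tokens each
print(tokens_seen(1000, 8, 4, 2, 1024))  # 65536000
```

With padded or variable-length batches this estimate overcounts, because it treats padding positions as tokens.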

Is there a simple way to track the number of tokens my Trainer has seen during training and report it to wandb? I suspect this requires a custom WandbCallback, but it isn't clear where, if anywhere, the number of tokens is tracked in the Trainer state.
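When batches are padded, counting only real tokens requires looking at each batch. The following minimal sketch does that from the attention_mask that Hugging Face tokenizers return, and it does not use any Trainer API. The class TokenCounter and the "tokens_seen" log key are hypothetical. Wiring it into training is not shown; it would have to run wherever each batch of inputs is visible, for example in a Trainer subclass. Under multi-GPU training each process would count only its own batches, so per-process totals would need to be summed.

```python
from typing import Dict, List


class TokenCounter:
    """Hypothetical helper: accumulates non-padding tokens across batches."""

    def __init__(self) -> None:
        self.num_tokens_seen = 0

    def update(self, batch: Dict[str, List[List[int]]]) -> int:
        # attention_mask holds 1 for real tokens and 0 for padding,
        # so its sum is the number of real tokens in the batch.
        if "attention_mask" in batch:
            n = sum(sum(row) for row in batch["attention_mask"])
        else:
            # Without a mask, count every position in input_ids.
            n = sum(len(row) for row in batch["input_ids"])
        self.num_tokens_seen += n
        return n

    def log_dict(self) -> Dict[str, int]:
        # Arbitrary key name; a dict like this could be passed to wandb.log.
        return {"tokens_seen": self.num_tokens_seen}


counter = TokenCounter()
counter.update({"input_ids": [[5, 6, 7, 0], [8, 9, 0, 0]],
                "attention_mask": [[1, 1, 1, 0], [1, 1, 0, 0]]})
counter.update({"input_ids": [[1, 2, 3]]})
print(counter.log_dict())  # {'tokens_seen': 8}
```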

Thanks!