Track number of tokens seen during training in wandb with Trainer API

I have a training loop that uses the Trainer API and reports the default metrics to Weights & Biases.

It is common to report the number of tokens seen during model training (not the number of steps or examples seen) in order to study scaling behavior. However, I don’t see a straightforward way to do this.

Is there a simple way to track the number of tokens my Trainer has seen during training and report it to wandb? I can see this might require a custom WandbCallback, but it isn’t clear where, if at all, the number of tokens is even tracked in the Trainer state.

Thanks!


Did you find a way?
Hard to believe there’s nobody who wants this feature.

No. I don’t see an easy way to implement this without refactoring the entire Trainer.__init__() and Trainer._inner_training_loop() methods, which seems like a total mess.

I think transformers should instead include this behavior by default. I created an issue on their GitHub page and am willing to take a stab at an implementation if they give it the green light and provide some guidance on design. Please upvote the issue and comment there!
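
In the meantime, one rough workaround that avoids touching the Trainer internals might be to wrap the data collator and count tokens there. The sketch below is only that: `TokenCountingCollator` and `tokens_seen` are hypothetical names, and it assumes each batch is a dict with an `attention_mask` in which 1 marks a real token and 0 marks padding.

```python
class TokenCountingCollator:
    """Hypothetical wrapper: delegates to an existing collator and keeps a
    running total of non-padding tokens across all batches it has built."""

    def __init__(self, collator):
        self.collator = collator  # any callable mapping a list of features to a batch dict
        self.tokens_seen = 0

    def __call__(self, features):
        batch = self.collator(features)
        mask = batch["attention_mask"]  # assumption: 1 = real token, 0 = padding
        if hasattr(mask, "sum"):
            # Tensor-like or array-like mask (e.g. torch or NumPy)
            self.tokens_seen += int(mask.sum())
        else:
            # Plain nested lists
            self.tokens_seen += sum(sum(row) for row in mask)
        return batch
```

The wrapped collator could then be passed as `data_collator=` to the Trainer, and a small `TrainerCallback` could read `tokens_seen` in its `on_log` hook and send it to wandb with `wandb.log`. This is approximate at best: the count runs ahead of the optimizer because batches are prefetched, it will not update in the main process if the dataloader uses worker processes, and in multi-GPU runs each process only counts its own batches.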
