I have several datasets mixed together with `datasets.interleave_datasets()`.
In addition to the overall training loss, I would like to plot one loss curve per dataset in wandb, so I can see whether the model is overfitting to any individual dataset. In this setup, a particular dataset may well produce no examples between two reporting points, and that is fine.
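Here is a minimal sketch of the per-dataset bookkeeping described above. It assumes the per-example losses and their source names are available at each step. `PerSourceLossTracker` is a hypothetical helper, not part of `transformers` or `wandb`. A source that produced no examples in a window is simply left out of that report. The returned dict could then be passed to something like `wandb.log()`.

```python
from collections import defaultdict


class PerSourceLossTracker:
    """Hypothetical helper: accumulates per-example losses by source name
    and reports the mean loss for each source since the last report."""

    def __init__(self):
        self._sums = defaultdict(float)
        self._counts = defaultdict(int)

    def update(self, losses, sources):
        # losses: per-example loss values; sources: the matching source names
        for loss, source in zip(losses, sources):
            self._sums[source] += float(loss)
            self._counts[source] += 1

    def report(self):
        # Only sources seen since the last report appear in the result, so a
        # dataset with no examples in this window is skipped, not logged as 0.
        means = {
            f"train/loss_{source}": self._sums[source] / self._counts[source]
            for source in self._counts
        }
        self._sums.clear()
        self._counts.clear()
        return means


tracker = PerSourceLossTracker()
tracker.update([0.5, 1.5, 2.0], ["wiki", "code", "wiki"])
print(tracker.report())  # {'train/loss_wiki': 1.25, 'train/loss_code': 1.5}
tracker.update([3.0], ["code"])  # "wiki" produced nothing in this window
print(tracker.report())  # {'train/loss_code': 3.0}
```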
I tried to `map()` over each dataset before interleaving them to add a `source` identifier column, but the `inputs` dict in `Trainer.compute_loss()` doesn't contain a `source` key. The documentation for `Trainer` says:
> If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
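The removal can be pictured with a simplified sketch. It assumes the filtering is driven by the parameter names in the signature of `model.forward()`. As far as I can tell, the real `Trainer` also keeps a few label-related names, which this sketch omits. `forward` and `kept_columns` are hypothetical stand-ins, not library code. The `source` column is dropped because `forward()` has no `source` parameter. In `transformers` this behavior is governed by `TrainingArguments(remove_unused_columns=...)`, which defaults to `True`.

```python
import inspect


def forward(input_ids, attention_mask=None, labels=None):
    """Hypothetical stand-in for a model's forward(); only its signature matters."""
    return None


def kept_columns(forward_fn, column_names):
    # Simplified filtering (assumption): keep only the columns whose names are
    # parameters of forward_fn. Label-related extras kept by Trainer are omitted.
    params = set(inspect.signature(forward_fn).parameters)
    return [name for name in column_names if name in params]


columns = ["input_ids", "attention_mask", "labels", "source"]
print(kept_columns(forward, columns))  # ['input_ids', 'attention_mask', 'labels']
```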
Are there standard ways of achieving what I want, or is there a way to keep the key around?