Plotting separate loss curves for different datasets

I have multiple datasets mixed together using datasets.interleave_datasets().

In addition to the overall training loss, I would like to plot one loss curve per dataset in wandb so I can see how much the model is overfitting to each dataset. In this setup a particular dataset may produce no examples at all between two reporting points, and that is fine.
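One common pattern for per-dataset curves is to accumulate loss sums and example counts per source between logging steps, then log the per-source means and reset. Below is a minimal, stdlib-only sketch of that bookkeeping. It assumes you can get an unreduced per-example loss together with each example's source tag (for instance by computing cross-entropy with `reduction="none"` yourself in a `compute_loss` override). `PerSourceLossTracker` and the `train/loss_` prefix are hypothetical names, not part of transformers or wandb.

```python
from collections import defaultdict
from typing import Dict, Iterable


class PerSourceLossTracker:
    """Hypothetical helper: per-source running loss between logging steps."""

    def __init__(self) -> None:
        self._sums: Dict[str, float] = defaultdict(float)
        self._counts: Dict[str, int] = defaultdict(int)

    def update(self, sources: Iterable[str], losses: Iterable[float]) -> None:
        # Pair each example's source tag with its unreduced loss value.
        sources, losses = list(sources), list(losses)
        if len(sources) != len(losses):
            raise ValueError("sources and losses must have the same length")
        for source, loss in zip(sources, losses):
            self._sums[source] += float(loss)
            self._counts[source] += 1

    def compute(self, prefix: str = "train/loss_") -> Dict[str, float]:
        # Mean loss per source seen since the last call. A source with no
        # examples in this window is simply absent, so no point is logged.
        metrics = {
            f"{prefix}{s}": self._sums[s] / self._counts[s] for s in self._counts
        }
        self._sums.clear()
        self._counts.clear()
        return metrics


tracker = PerSourceLossTracker()
tracker.update(["wiki", "code", "wiki"], [2.0, 1.0, 4.0])
print(tracker.compute())  # {'train/loss_wiki': 3.0, 'train/loss_code': 1.0}
print(tracker.compute())  # {}
```

At each logging step the returned dict could be passed to `wandb.log()`. Because a source that produced no examples is omitted rather than logged as zero, its curve just has a gap for that interval. Note that this averages per example; if the overall training loss is averaged per token, the two will not be directly comparable.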

I tried to map() my datasets before interleaving them to add a ‘source’ identifier to each example, but in Trainer.compute_loss() the inputs dict doesn’t contain this ‘source’ key. The documentation for Trainer says:

If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
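A possible workaround, sketched here under assumptions: set `TrainingArguments(remove_unused_columns=False)` so Trainer keeps the `source` column, then wrap the data collator so `source` is set aside before padding/tensorizing and re-attached to the batch as a plain list. In a `compute_loss` override you would then `inputs.pop("source")` before calling `model(**inputs)`. `make_source_aware_collator` is a hypothetical helper; with `remove_unused_columns=False`, any other extra columns (e.g. raw text) also reach the collator, so they would need to be dropped beforehand, e.g. with `Dataset.remove_columns`.

```python
from typing import Any, Callable, Dict, List

Features = List[Dict[str, Any]]


def make_source_aware_collator(
    base_collator: Callable[[Features], Dict[str, Any]],
    key: str = "source",
) -> Callable[[Features], Dict[str, Any]]:
    """Hypothetical wrapper: the base collator never sees `key`, so it does
    not try to pad or tensorize the source tags."""

    def collate(features: Features) -> Dict[str, Any]:
        # Collect the tags and build copies without them (originals untouched).
        sources = [f[key] for f in features]
        stripped = [{k: v for k, v in f.items() if k != key} for f in features]
        batch = dict(base_collator(stripped))
        # Re-attach the tags as a plain list, one per example in the batch.
        batch[key] = sources
        return batch

    return collate


# Toy base collator standing in for e.g. a padding collator.
collate = make_source_aware_collator(
    lambda feats: {"input_ids": [f["input_ids"] for f in feats]}
)
print(collate([{"input_ids": [1, 2], "source": "wiki"},
               {"input_ids": [3], "source": "code"}]))
# {'input_ids': [[1, 2], [3]], 'source': ['wiki', 'code']}
```

Storing the source as an integer id instead of a string is another option if you would rather keep everything tensor-friendly.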

Are there standard ways of achieving what I want, or is there a way to keep the key around?