How to pass multiple datasets into Trainer for Knowledge distillation in NMT

I am trying to apply knowledge distillation to a domain adaptation problem in NMT. I understand that to create my custom loss function I need to subclass the Trainer class and override the compute_loss function. Since I'm following sequence-level distillation, I need to pass data from 6 different domains, calculate the individual loss per domain, and then calculate the global loss. It would be wonderful if someone could point me to any resource that shows how to pass multiple datasets into the Trainer class.
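One common way to approach this (a sketch, not the official Trainer API for multiple datasets) is to tag each example with a domain id, concatenate the 6 datasets into one, and split the batch by domain inside the loss computation. The helper below shows only that splitting logic with toy tensors; `NUM_DOMAINS`, `per_domain_loss`, and the random data are all illustrative stand-ins, not part of any library.

```python
import torch
import torch.nn.functional as F

NUM_DOMAINS = 6  # illustrative; matches the 6 domains described above

def per_domain_loss(logits, labels, domain_ids, num_domains=NUM_DOMAINS):
    """Cross-entropy split by domain id, then averaged into a global loss.

    Inside a subclassed Trainer, the same logic could live in an
    overridden compute_loss, reading domain_ids from the batch.
    """
    losses = []
    for d in range(num_domains):
        mask = domain_ids == d
        if mask.any():  # skip domains absent from this batch
            losses.append(F.cross_entropy(logits[mask], labels[mask]))
    return torch.stack(losses).mean(), losses

# Toy usage with random tensors standing in for translation batches.
torch.manual_seed(0)
vocab = 10
logits = torch.randn(24, vocab)
labels = torch.randint(0, vocab, (24,))
domain_ids = torch.randint(0, NUM_DOMAINS, (24,))

global_loss, domain_losses = per_domain_loss(logits, labels, domain_ids)
print(global_loss.item(), len(domain_losses))
```

The same idea carries over to sequence-level distillation by replacing the gold labels with the teacher's decoded outputs per domain.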

Did you get any idea for this? @velmen

Actually, I have stopped using the Trainer class as it is very hard to customize for the task I described. I am currently working on a solution by writing a custom PyTorch loop to train the models.
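A custom loop for this setup could draw one batch per domain each step, compute the per-domain losses, and backpropagate their mean as the global loss. This is only a minimal sketch of that pattern: the tiny linear "model", the 3 toy dataloaders (instead of 6), and the MSE loss are placeholders for the real NMT student model and distillation loss.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
model = nn.Linear(8, 1)                       # placeholder student model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()                        # placeholder for the KD loss

# One DataLoader per domain (3 toy domains here instead of 6).
loaders = [
    DataLoader(TensorDataset(torch.randn(16, 8), torch.randn(16, 1)),
               batch_size=4)
    for _ in range(3)
]

# zip(*loaders) yields one batch from every domain per step and stops
# when the shortest loader is exhausted.
for step, batches in enumerate(zip(*loaders)):
    domain_losses = [loss_fn(model(x), y) for x, y in batches]
    global_loss = torch.stack(domain_losses).mean()
    optimizer.zero_grad()
    global_loss.backward()
    optimizer.step()

print(round(global_loss.item(), 4))
```

If the domain datasets have very different sizes, `itertools.cycle` on the shorter loaders (or weighted sampling) is a common alternative to plain `zip`.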

I have done some research on this but I'm stuck here. If you don't mind, could we work on it together?