Update different parts of the model with different datasets

Hi,

Is there a convenient way to train a model on two datasets, where each dataset updates a different part of the model?

For example, I want to train only the embeddings with dataset A and the entire model with dataset B. However, I want to switch datasets (and which part of the model gets updated) dynamically during training, for example a few batches from dataset A followed by one from dataset B, rather than switching only once one dataset is exhausted. This would make it essentially a multi-task scenario.
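One way to get the "a few batches from A, then one from B" schedule is a small generator that interleaves two loaders and tags each batch with its source, so the training loop can decide what to update. This is a minimal sketch under assumptions of my own: the names `interleave` and `repeat_forever` are hypothetical, the ratio `a_per_b` is fixed, one pass ends when loader B is exhausted, and loader A is restarted whenever it runs out.

```python
from typing import Any, Iterable, Iterator, Tuple


def repeat_forever(loader: Iterable) -> Iterator[Any]:
    """Re-iterate the loader each time it is exhausted (hypothetical helper).

    Unlike itertools.cycle, this calls iter(loader) again on every pass,
    so a shuffling DataLoader reshuffles instead of replaying cached batches.
    """
    while True:
        empty = True
        for batch in loader:
            empty = False
            yield batch
        if empty:
            raise ValueError("loader is empty")


def interleave(loader_a: Iterable, loader_b: Iterable,
               a_per_b: int = 3) -> Iterator[Tuple[str, Any]]:
    """Yield ("A", batch) a_per_b times, then ("B", batch) once, repeatedly.

    Assumption: one pass ends when loader_b is exhausted.
    """
    it_a = repeat_forever(loader_a)
    for batch_b in loader_b:
        for _ in range(a_per_b):
            yield "A", next(it_a)
        yield "B", batch_b
```

In a training loop this would be used as `for source, batch in interleave(loader_a, loader_b): ...`, branching on `source` to choose which parameters to update. Plain Python lists stand in for DataLoaders here; anything iterable works.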

I assume I need to write my own trainer class (and maybe also my own dataset?) for that. Any ideas on how to switch between the two datasets and training modes efficiently? I don't know whether model.freeze() has a large overhead (I would assume not, since it probably just turns off gradient tracking on the parameters, but I'm not sure).
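A sketch of the switching itself in plain PyTorch, without a custom trainer: keep one optimizer per task (one over the embedding parameters only, one over all parameters) and toggle `requires_grad` on the rest of the model before each step. `nn.Module.requires_grad_` only flips a flag on each parameter, so switching per batch should be cheap, and frozen parameters are skipped during backward. `TinyModel`, `train_step`, the SGD optimizers and the shared cross-entropy loss are illustrative assumptions; in a real multi-task setup dataset A would likely have its own head or loss.

```python
import torch
from torch import nn


class TinyModel(nn.Module):  # hypothetical model: embeddings + classifier head
    def __init__(self, vocab: int = 100, dim: int = 16, n_classes: int = 4):
        super().__init__()
        self.embedding = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, n_classes)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Average token embeddings, then classify.
        return self.head(self.embedding(tokens).mean(dim=1))


model = TinyModel()
loss_fn = nn.CrossEntropyLoss()
# One optimizer per task: A only updates the embeddings, B updates everything.
# Each optimizer keeps its own state (relevant for momentum/Adam).
opt_a = torch.optim.SGD(model.embedding.parameters(), lr=0.1)
opt_b = torch.optim.SGD(model.parameters(), lr=0.1)


def train_step(source: str, tokens: torch.Tensor, labels: torch.Tensor) -> float:
    embed_only = source == "A"
    # Freeze the head for dataset-A batches; this just flips requires_grad.
    model.head.requires_grad_(not embed_only)
    opt = opt_a if embed_only else opt_b
    opt.zero_grad()
    loss = loss_fn(model(tokens), labels)
    loss.backward()
    opt.step()
    return loss.item()
```

With the interleaving from the question, the loop body would simply be `train_step(source, tokens, labels)`. If model.freeze() refers to PyTorch Lightning's `LightningModule.freeze()`, it freezes all parameters of the module, so for a partial freeze, toggling `requires_grad_` on a submodule as above is more direct.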

I hope someone has some advice. Thanks!
VC