By design, `compute_loss()` is called with a single `inputs` dict per batch. `inputs` is whatever your `get_train_dataloader()` yields, one batch at a time. So only one batch is processed per call to `training_step()`, which in turn calls `compute_loss()`.
```python
from torch.utils.data import DataLoader
from transformers import Trainer


class MultiDatasetTrainer(Trainer):
    def __init__(self, dataset_a=None, dataset_b=None, bs_a=32, bs_b=128, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_a = dataset_a
        self.dataset_b = dataset_b
        self.bs_a = bs_a
        self.bs_b = bs_b

    def get_train_dataloader(self):
        loader_a = DataLoader(
            self.dataset_a,
            batch_size=self.bs_a,
            shuffle=True,
            collate_fn=self.data_collator,
        )
        loader_b = DataLoader(
            self.dataset_b,
            batch_size=self.bs_b,
            shuffle=True,
            collate_fn=self.data_collator,
        )
        # You'll need to write CombinedLoader (or something like it) yourself;
        # see the sketch after this snippet.
        return CombinedLoader(loader_a, loader_b, mode="sequential")

    # **kwargs absorbs num_items_in_batch, which newer transformers versions pass.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        source = inputs.pop("source")  # "A" or "B", tagged by CombinedLoader
        outputs = model(**inputs)
        loss = outputs.loss
        # Optional: down-weight dataset B (pick a factor that suits your task)
        if source == "B":
            loss = loss * 0.5  # placeholder weight
        return (loss, outputs) if return_outputs else loss
```
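
`CombinedLoader` is not a transformers class, so here is a minimal sketch of the sequential version the comment alludes to. It assumes each batch is a mutable dict (which is what the default data collators return), and it is also where the `"source"` key that `compute_loss()` pops gets attached:

```python
class CombinedLoader:
    """Hypothetical helper: iterate loader_a to exhaustion, then loader_b.

    Tags every batch dict with a "source" key ("A" or "B") so that
    compute_loss() can tell the two datasets apart.
    """

    def __init__(self, loader_a, loader_b, mode="sequential"):
        if mode != "sequential":
            raise ValueError("only 'sequential' is sketched here")
        self.loader_a = loader_a
        self.loader_b = loader_b

    def __iter__(self):
        for batch in self.loader_a:
            batch["source"] = "A"  # assumes batches are mutable dicts
            yield batch
        for batch in self.loader_b:
            batch["source"] = "B"
            yield batch

    def __len__(self):
        # Trainer calls len() on the train dataloader to size the epoch.
        return len(self.loader_a) + len(self.loader_b)
```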
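
Usage would then look something like the following sketch; `model`, `training_args`, `data_collator`, and the two datasets are assumed to already exist, and `train_dataset` can be left out because `get_train_dataloader()` is overridden:

```python
trainer = MultiDatasetTrainer(
    dataset_a=dataset_a,
    dataset_b=dataset_b,
    bs_a=32,
    bs_b=128,
    model=model,
    args=training_args,
    data_collator=data_collator,
)
trainer.train()
```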