Help with custom loss function

Hey everyone!

I’m writing a custom loss function with the Transformers library, using a manual PyTorch training loop. I need to combine the cross-entropy loss on the train set with the cross-entropy loss on another labeled set, which was artificially generated (labels inferred from another model).

Loss = train_loss + artificial_loss.

I usually have far more artificial data than train data, so within a single training loop I iterate over both sets, drawing N batches from the artificial set for every train batch (with N being the size ratio between the two sets).
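To make the intent concrete, here's a tiny self-contained sketch of what I'm trying to do (toy linear model and random data, not my real code):

import torch
import torch.nn as nn

model = nn.Linear(8, 3)                      # stand-in for the real transformer
loss_fn = nn.CrossEntropyLoss()
N = 4                                        # artificial-to-train batch ratio

# one labeled (train set) batch
x_lab, y_lab = torch.randn(16, 8), torch.randint(0, 3, (16,))
lab_loss = loss_fn(model(x_lab), y_lab)

# N artificially labeled batches for this single train batch
unl_losses = []
for _ in range(N):
    x_unl, y_unl = torch.randn(16, 8), torch.randint(0, 3, (16,))
    unl_losses.append(loss_fn(model(x_unl), y_unl))
artificial_loss = torch.stack(unl_losses).mean()

loss = lab_loss + artificial_loss            # Loss = train_loss + artificial_loss
loss.backward()

In my real loop the model is a transformer, the labels on the artificial set come from another model, and N is the size ratio between the two sets.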

The problem: when I print the losses, the combined loss and the labeled (train set) loss decrease, but the unlabeled (artificial set) loss keeps increasing to unusually large values. What am I doing wrong?

Training Log Example:
https://drive.google.com/file/d/1-tA9Mn0-yc4RHweP7wO6-H7ddUKc94/view?usp=sharing

Training code:

def train(
    model,
    train_dataloader,
    optimizer,
    scheduler,
    val_dataloader=None,
    evaluate_during_training=False,
    is_student=False,
    unlabeled_dataloader=None,
    unl_to_label_batch_ratio=None,
):
    progress_bar = tqdm(range(CFG.num_train_epochs * len(train_dataloader)))
    log("Start training...\n")
    for epoch_i in range(CFG.num_train_epochs):
        if is_student:
            log(
                f"{'Epoch':^7} | {'Labeled Batch':^14} | {'Unlabeled Batch':^16} | "
                f"{'Train Loss':^11} | {'Labeled Loss':^13} | "
                f"{'Unlabeled Loss':^15} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}"
            )
            log("-"*130)
        else:
            log(
                f"{'Epoch':^7} | {'Train Batch':^12} | "
                f"{'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}"
            )
            log("-"*80)

        # measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()

        # reset tracking variables at the beginning of each epoch
        total_loss, batch_loss, batch_unl_loss, batch_lab_loss, batch_counts = 0, 0, 0, 0, 0

        # train loop
        model.train()
        loss_fn = nn.CrossEntropyLoss()
        for step, batch in enumerate(train_dataloader):
            batch_counts += 1
            batch_inputs = {k: v.to(CFG.device) for k, v in batch.items()}

            optimizer.zero_grad()
            output = model(**batch_inputs)

            # if model is student, train with the noised data as well
            if is_student:
                text_col = "text_augmented" if CFG.augmented_data else "text"
                unl_logits = []
                unl_labels = []
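                # collect logits and (model-inferred) labels from unl_to_label_batch_ratio unlabeled batches per labeled batch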
                for i in range(unl_to_label_batch_ratio):
                    unl_batch = next(iter(unlabeled_dataloader))
                    unl_inputs = tokenizer.batch_encode_plus(
                        unl_batch[text_col],
                        padding="max_length",
                        truncation=True,
                        max_length=CFG.max_seq_len,
                        return_tensors="pt"
                    )
                    unl_inputs["labels"] = unl_batch["labels"].clone().detach()

                    unl_batch_inputs = {k: v.to(CFG.device) for k, v in unl_inputs.items()}
                    unl_output = model(**unl_batch_inputs)

                    unl_logits.append(unl_output.logits)
                    unl_labels.append(unl_inputs["labels"])

                unl_labels = torch.cat(unl_labels).to(CFG.device)
                unl_logits = torch.cat(unl_logits)

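                # one cross-entropy over the concatenated unlabeled batches, added to the labeled loss from the model output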
                unl_loss = loss_fn(unl_logits, unl_labels)
                lab_loss = output.loss
                loss = lab_loss + unl_loss

                batch_lab_loss += lab_loss.item()
                batch_unl_loss += unl_loss.item()

            else:
                loss = output.loss

            batch_loss += loss.item()
            total_loss += loss.item()

            loss.backward()

            if CFG.clip_grad:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()
            progress_bar.update(1)

            if (step % 100 == 0 and step != 0) or (step == len(train_dataloader) - 1):
                time_elapsed = time.time() - t0_batch

                # Print training results
                if is_student:
                    log(
                        f"{epoch_i + 1:^7} | {step:^14} | {(step * unl_to_label_batch_ratio):^16} | "
                        f"{batch_loss / batch_counts:^11.6f} | "
                        f"{batch_lab_loss / batch_counts:^13.6f} | "
                        f"{batch_unl_loss / batch_counts:^15.6f} | "
                        f"{'-':^10} | {'-':^9} | {time_elapsed:^9.2f}"
                    )

                else:
                    log(
                        f"{epoch_i + 1:^7} | {step:^12} | {batch_loss / batch_counts:^12.6f} | "
                        f"{'-':^10} | {'-':^9} | {time_elapsed:^9.2f}"
                    )

                batch_loss, batch_lab_loss, batch_unl_los, batch_counts = 0, 0, 0, 0
                t0_batch = time.time()

        # Calculate the average loss over the entire training data
        avg_train_loss = total_loss / len(train_dataloader)
        if evaluate_during_training:
            val_loss, val_accuracy = evaluate(model, val_dataloader)
            time_elapsed = time.time() - t0_epoch

            if is_student:
                log("-"*130)
                log(
                    f"{epoch_i + 1:^7} | {'-':^14} | {'-':^16} | {avg_train_loss:^11.6f} | "
                    f"{'-':^13} | {'-':^15} | {val_loss:^10.6f} | "
                    f"{val_accuracy:^9.2f} | {time_elapsed:^9.2f}"
                )
                log("-"*130)
            else: 
                log("-"*80)
                log(
                    f"{epoch_i + 1:^7} | {'-':^12} | {avg_train_loss:^12.6f} | "
                    f"{val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}"
                )
                log("-"*80)
        log("\n")
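For context, the student-mode call looks roughly like this (model and loader names are placeholders for the ones in my script; the ratio is just the size ratio between the two loaders):

unl_to_label_batch_ratio = max(1, len(unlabeled_dataloader) // len(train_dataloader))

train(
    student_model,
    train_dataloader,
    optimizer,
    scheduler,
    val_dataloader=val_dataloader,
    evaluate_during_training=True,
    is_student=True,
    unlabeled_dataloader=unlabeled_dataloader,
    unl_to_label_batch_ratio=unl_to_label_batch_ratio,
)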