Custom loss does not improve target-word accuracy

Hi,

It’s my first time training a model. I fine-tuned a T5 model to create sentences that contain a specified word. The initial training went well: the model generates coherent sentences, and about 75% of the generated sentences include the target word.
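
(For reference, the 75% figure comes from generating on a held-out set and checking whether the target word appears in the output text. Below is a simplified sketch of that check; the function name, prompt handling, and generation settings are approximations, not my exact evaluation code.)

import torch

def target_word_hit_rate(model, tokenizer, prompts, target_words, device="cpu"):
    # Generate one sentence per prompt and count how often the target word shows up.
    hits = 0
    model.eval()
    with torch.no_grad():
        for prompt, word in zip(prompts, target_words):
            enc = tokenizer(prompt, return_tensors="pt").to(device)
            output_ids = model.generate(**enc, max_new_tokens=32)
            sentence = tokenizer.decode(output_ids[0], skip_special_tokens=True)
            hits += int(word.lower() in sentence.lower())
    return hits / len(prompts)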

Then I thought I could modify the loss function to add a penalty when the generated sentence does not include the target word, hoping to improve how reliably the target word is used.

I came up with the following code and integrated it into a custom trainer. The idea is to estimate the probability of the target sequence (the target word may consist of more than one token) from the predicted logits and add a penalty term to the base loss:

import torch
from torch.nn import CrossEntropyLoss
from transformers import Seq2SeqTrainer

def custom_loss(model, inputs, target_word_ids, alpha=5.0):
    # 1. Forward pass
    outputs = model(**inputs)
    logits = outputs.logits
    labels = inputs["labels"]

    # 2. Base cross-entropy loss
    loss_fct = CrossEntropyLoss(ignore_index=model.config.pad_token_id)
    base_loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

    # 3. Compute token probabilities
    probs = torch.softmax(logits, dim=-1)  # [batch_size, seq_len, vocab_size]
    batch_size = logits.size(0)
    target_probs = torch.zeros(batch_size, device=logits.device)

    # 4. Compute the penalty for each sequence in the batch
    for i in range(batch_size):
        target_seq = target_word_ids[i][target_word_ids[i] != 0]  # Drop padding
        if len(target_seq) > 0:
            # Take the first N tokens
            n_tokens = min(2, len(target_seq))
            tokens_to_check = target_seq[:n_tokens]

            # Compute a probability for each token
            token_probs = []
            for token in tokens_to_check:
                # Probabilities of this token at every position
                probs_for_token = probs[i, :, token]
                """ Prob tiene la probabilidad suave de que la secuencia de tokens aparezca en toda la frase, le da más
                    importancia a los valores altos.
                    Se podría coger solo la probabilidad de la posición más probable, pero entonces los gradientes solo
                    se propagarían hacia los pesos que generaron esa posición de la secuencia, con esto se supone que se
                    propaga más "suave" (caliente/frío) hacia todos los pesos
                """
                prob = torch.logsumexp(probs_for_token * 5.0, dim=0) / 5.0
                token_probs.append(prob)

            # Average the token probabilities
            target_probs[i] = sum(token_probs) / len(token_probs)
    # 5. Final penalty; use log so that a high probability adds only a small penalty
    penalty = -torch.log(target_probs + 1e-10).mean()

    # 6. Combine the losses
    total_loss = base_loss + alpha * penalty

    outputs.loss = total_loss
    return total_loss, outputs

class CustomTrainer(Seq2SeqTrainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        if "target_word" in inputs:  # Para training y validation loss
            target_word = inputs.pop("target_word")
            loss, outputs = custom_loss(model, inputs, target_word, alpha=ALPHA_LOSS)
            return (loss, outputs) if return_outputs else loss
        else:  # For generation
            outputs = model(**inputs)
            loss = outputs.loss
            return (loss, outputs) if return_outputs else loss

    def prediction_step(
        self,
        model,
        inputs,
        prediction_loss_only,
        ignore_keys=None,
    ):
        inputs_copy = {}
        expected_keys = ["input_ids", "attention_mask"]

        for key in expected_keys:
            if key in inputs:
                inputs_copy[key] = inputs[key]

        return super().prediction_step(
            model,
            inputs_copy,
            prediction_loss_only,
            ignore_keys=ignore_keys,
        )
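
For context, the target word reaches compute_loss as a 0-padded tensor of token ids under the "target_word" key of each batch. This is roughly how it gets attached during collation; the class name and details below are a simplified sketch, not my exact collator:

import torch
from transformers import DataCollatorForSeq2Seq

class CollatorWithTargetWord(DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None):
        # Pull out the pre-tokenized target word ids before the parent collator runs,
        # since it does not know how to pad this extra field.
        target_ids = [feature.pop("target_word") for feature in features]
        batch = super().__call__(features, return_tensors=return_tensors)

        # Right-pad each target word with 0 so the batch stacks into one tensor;
        # custom_loss strips the zeros again with target_word_ids[i] != 0.
        max_len = max(len(ids) for ids in target_ids)
        padded = [list(ids) + [0] * (max_len - len(ids)) for ids in target_ids]
        batch["target_word"] = torch.tensor(padded, dtype=torch.long)
        return batch

The collator is passed to the trainer as data_collator; note that the extra column has to survive until collation (e.g. remove_unused_columns=False in the training arguments).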

I trained models with several values of alpha, but the penalty doesn’t seem to improve the target-word accuracy.

Does anyone know what might be wrong? Could it be that the gradients are not propagating correctly? Or maybe the loss function doesn’t make sense at all?

Any help would be appreciated.


Please show your error clearly.
I can’t help you because your code is too long, and you should describe your error more clearly.
