Supervised Fine-tuning Trainer - Custom Loss Function

I’m attempting to fine-tune a Llama-2-7b-hf model to predict quantities (in g or ml) of ingredients in a recipe.

For example:

Inputs to model:
Name of dish or drink: harissa chicken traybake
List of ingredients:
[chicken legs, harissa paste, garlic, lemon, cherry tomatoes, new potatoes, Kalamata olives, olive oil]

Ground truth output:
Ingredients and quantities:
[chicken legs, 200, harissa paste, 15, garlic, 15, lemon, 17.5, cherry tomatoes, 100, new potatoes, 87.5, Kalamata olives, 12.5, olive oil, 7.5]

Fine-tuning with the default SFTTrainer parameters produces output like this:
[chicken legs, 100, harissa paste, 10, garlic, 10, lemon, 10, cherry tomatoes, 100, new potatoes, 100, Kalamata olives, 10, olive oil, 10]

I believe I need a custom loss function that attempts to minimise the mean squared error between the quantity values. I have subclassed the SFTTrainer as below, but I get a gradient of zero. Is it possible to create a custom loss function in this way? I’m fairly new to this so any help would be very much appreciated!

import torch
from torch import nn
from trl import SFTTrainer

# `tokenizer` and the `quantities` helper are assumed to be defined elsewhere in the script.

class CustomSFTTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        # get label and prediction tokens
        labels = inputs.get("labels")
        outputs = model(**inputs)
        predictions = outputs.get("logits")
        # decode predictions and labels
        predicted_token_ids = torch.argmax(predictions, dim=-1)
        decoded_predictions = [tokenizer.decode(p.tolist()) for p in predicted_token_ids]
        decoded_labels = [tokenizer.decode(l.tolist()) for l in labels]

        # function to output quantities to a list
        predicted_quantities, actual_quantities = quantities(decoded_predictions, decoded_labels)
        predicted_tensor = torch.tensor(predicted_quantities, device=model.device)
        actual_tensor = torch.tensor(actual_quantities, device=model.device)
        # Compute MSE loss
        loss_function = nn.MSELoss()
        loss = loss_function(predicted_tensor, actual_tensor)
        return (loss, outputs) if return_outputs else loss

Did you solve it?

Yes, the issue came from using the argmax function to select the most probable token during decoding. The argmax operation is non-differentiable: it outputs discrete indices, so no gradient flows through it during backpropagation. As a result, the loss gradient never reaches the logits, and the optimizer has nothing with which to adjust the model parameters.

To resolve this and enable a differentiable pipeline that allows for the computation and use of gradients in backpropagation, I switched from a hard selection mechanism (argmax) to a soft selection mechanism (softmax). Unlike argmax, the softmax function provides a probability distribution over possible output tokens, which is differentiable with respect to the input predictions.
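To make the contrast concrete, here is a minimal sketch that is not from the original thread and uses only the standard library. It takes made-up logits over three hypothetical candidate quantities and estimates, by finite differences, how a hard pick (argmax) and a soft pick (softmax-weighted average) respond to a small change in one logit. The names `CANDIDATES`, `hard_pick`, `soft_pick`, and `finite_diff` are illustrative assumptions.

```python
import math

# Hypothetical candidate quantities (in g) that three tokens would decode to.
CANDIDATES = [10.0, 100.0, 200.0]

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def hard_pick(logits):
    """Argmax selection: quantity of the highest-scoring candidate."""
    best = max(range(len(logits)), key=lambda i: logits[i])
    return CANDIDATES[best]

def soft_pick(logits):
    """Softmax selection: probability-weighted average quantity."""
    return sum(p * q for p, q in zip(softmax(logits), CANDIDATES))

def finite_diff(fn, logits, index, eps=1e-4):
    """Central finite-difference estimate of d fn / d logits[index]."""
    up = list(logits)
    up[index] += eps
    down = list(logits)
    down[index] -= eps
    return (fn(up) - fn(down)) / (2 * eps)

logits = [2.0, 1.0, 0.5]  # made-up values
hard_grad = finite_diff(hard_pick, logits, index=2)
soft_grad = finite_diff(soft_pick, logits, index=2)
print("hard pick gradient:", hard_grad)  # piecewise constant, so zero here
print("soft pick gradient:", soft_grad)  # positive: raising this logit raises the estimate
```

The hard pick only changes when the ranking of the logits flips, so a small nudge leaves it untouched. The soft pick moves a little with every nudge, which is the signal an optimizer needs.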

Here’s the specific change:

Original line using argmax (non-differentiable selection):

predicted_token_ids = torch.argmax(predictions, dim=-1)

Modified to use a custom softmax function (differentiable selection):

import torch
import torch.nn.functional as F

def softmax_selection(predictions, temperature=1.0):
    """
    Apply softmax to model predictions and sample a token based on the resulting probabilities.

    Args:
        predictions (torch.Tensor): The tensor containing the raw predictions from the model.
        temperature (float): Temperature parameter to adjust the sharpness of the probability distribution.
                              A lower temperature makes the distribution sharper.

    Returns:
        torch.Tensor: Tensor containing the selected token IDs.
    """
    # Apply softmax with temperature
    probs = F.softmax(predictions / temperature, dim=-1)

    # torch.multinomial accepts only 1-D or 2-D input, so flatten the leading dimensions
    flat_probs = probs.reshape(-1, probs.size(-1))

    # Sampling a token based on the probabilities
    sampled_tokens = torch.multinomial(flat_probs, num_samples=1)

    # Restore the leading dimensions, matching the shape torch.argmax(dim=-1) would return
    return sampled_tokens.reshape(probs.shape[:-1])

# Select token IDs based on softmax probabilities
selected_token_ids = softmax_selection(predictions, temperature=temperature)
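
One caveat worth checking before relying on this: `torch.multinomial`, like `torch.argmax`, returns integer token IDs, and integer tensors do not carry gradients, so a loss built from decoded sampled tokens may still be detached from the logits. A common differentiable alternative, offered here as a sketch and not as the method used in this thread, is to take an expected value under the softmax probabilities. The sketch assumes a hypothetical tensor `candidate_values` that maps each vocabulary entry to a number; real tokenizers split numbers across several tokens, so this mapping is a simplification.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-in for model logits with shape (batch=2, vocab=4); not real Llama output.
logits = torch.randn(2, 4, requires_grad=True)
probs = F.softmax(logits, dim=-1)

# Sampling returns integer IDs, which are cut off from the autograd graph.
sampled = torch.multinomial(probs, num_samples=1)
print(sampled.dtype, sampled.requires_grad)  # torch.int64 False

# Hypothetical mapping from each vocabulary entry to a numeric quantity.
candidate_values = torch.tensor([10.0, 50.0, 100.0, 200.0])

# Expected quantity under the softmax distribution: differentiable in the logits.
expected = (probs * candidate_values).sum(dim=-1)
target = torch.tensor([87.5, 12.5])  # made-up ground-truth quantities

loss = F.mse_loss(expected, target)
loss.backward()

# A non-zero gradient here means the loss signal reaches the logits.
print(logits.grad.abs().sum().item())
```

If the printed gradient magnitude is non-zero, the MSE loss can influence the model; with decoded token IDs it cannot.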