Supervised Fine-tuning Trainer - Custom Loss Function

I’m attempting to fine-tune a Llama-2-7b-hf model to predict the quantities (in g or ml) of the ingredients in a recipe.

For example:

Inputs to model:
Name of dish or drink: harissa chicken traybake
List of ingredients:
[chicken legs, harissa paste, garlic, lemon, cherry tomatoes, new potatoes, Kalamata olives, olive oil]

Ground truth output:
Ingredients and quantities:
[chicken legs, 200, harissa paste, 15, garlic, 15, lemon, 17.5, cherry tomatoes, 100, new potatoes, 87.5, Kalamata olives, 12.5, olive oil, 7.5]

Fine-tuning with the default SFTTrainer parameters produces output like this:
[chicken legs, 100, harissa paste, 10, garlic, 10, lemon, 10, cherry tomatoes, 100, new potatoes, 100, Kalamata olives, 10, olive oil, 10]
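
To put a number on the gap between these two lists, here is a minimal sketch that pulls out the quantities and computes their mean squared error in plain Python. The helper names `parse_quantities` and `mean_squared_error` are made up for illustration, and the parsing assumes the strict "name, quantity, name, quantity" layout shown above with no commas inside ingredient names.

```python
def parse_quantities(text):
    """Hypothetical helper: return the numbers from a '[name, qty, name, qty, ...]' string."""
    items = [item.strip() for item in text.strip().strip("[]").split(",")]
    # Quantities sit at the odd positions (1, 3, 5, ...) of the flat list.
    return [float(q) for q in items[1::2]]


def mean_squared_error(predicted, actual):
    """Average of squared differences between two equal-length lists."""
    if len(predicted) != len(actual):
        raise ValueError("lists must have the same length")
    return sum((p - a) ** 2 for p, a in zip(predicted, actual)) / len(actual)


truth = ("[chicken legs, 200, harissa paste, 15, garlic, 15, lemon, 17.5, "
         "cherry tomatoes, 100, new potatoes, 87.5, Kalamata olives, 12.5, olive oil, 7.5]")
pred = ("[chicken legs, 100, harissa paste, 10, garlic, 10, lemon, 10, "
        "cherry tomatoes, 100, new potatoes, 100, Kalamata olives, 10, olive oil, 10]")

mse = mean_squared_error(parse_quantities(pred), parse_quantities(truth))
print(mse)  # 1284.375
```

Most of that error comes from the chicken legs (100 vs 200), which shows how strongly MSE weights the largest quantities.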

I believe I need a custom loss function that minimises the mean squared error between the predicted and ground-truth quantity values. I have subclassed SFTTrainer as below, but I get a gradient of zero. Is it possible to create a custom loss function in this way? I’m fairly new to this, so any help would be very much appreciated!

import torch
import torch.nn as nn
from trl import SFTTrainer

class CustomSFTTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super(CustomSFTTrainer, self).__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        # get label and prediction tokens
        labels = inputs.get("labels")
        outputs = model(**inputs)
        predictions = outputs.get("logits")
        # decode predictions and labels
        predicted_token_ids = torch.argmax(predictions, dim=-1)
        decoded_predictions = [tokenizer.decode(p.tolist()) for p in predicted_token_ids]
        decoded_labels = [tokenizer.decode(l.tolist()) for l in labels]

        # quantities() is a helper (not shown) that extracts the numeric quantities into two lists
        predicted_quantities, actual_quantities = quantities(decoded_predictions, decoded_labels)
        predicted_tensor = torch.tensor(predicted_quantities, device=model.device)
        actual_tensor = torch.tensor(actual_quantities, device=model.device)
        # Compute MSE loss
        loss_function = nn.MSELoss()
        loss = loss_function(predicted_tensor, actual_tensor)
        return (loss, outputs) if return_outputs else loss
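
A toy sketch of one likely reason for the missing gradient, using made-up tensors rather than the real model and tokenizer: `torch.argmax` returns integer ids with no gradient history, and building a new tensor from decoded Python numbers with `torch.tensor(...)` starts a fresh tensor that is disconnected from the logits. The second half shows one possible differentiable stand-in (an expected value under the softmax). It is only an illustration, and the `values` mapping from token index to number is a hypothetical assumption, not how the Llama tokenizer encodes quantities.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for model logits: batch of 1, 4 positions, vocabulary of 10 tokens.
logits = torch.randn(1, 4, 10, requires_grad=True)
target = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Path 1: mirrors the compute_loss above (argmax, then rebuild a tensor from Python numbers).
token_ids = torch.argmax(logits, dim=-1)
print(token_ids.requires_grad)  # False: argmax output carries no gradient history
rebuilt = torch.tensor([float(t) for t in token_ids[0].tolist()])
mse = nn.MSELoss()(rebuilt, target)
print(mse.requires_grad)  # False: this loss is not connected to the logits

# Path 2: a differentiable alternative, assuming (hypothetically) token i stands for the number i.
values = torch.arange(10, dtype=torch.float32)
probs = torch.softmax(logits, dim=-1)
expected = (probs * values).sum(dim=-1)  # shape (1, 4), still attached to the graph
soft_mse = nn.MSELoss()(expected, target.unsqueeze(0))
soft_mse.backward()
print(logits.grad is not None)  # True: gradients reach the logits
```

In the real setting a number spans several tokens, so a loss like Path 2 does not transfer directly; the point is only that every step between the logits and the loss has to be a differentiable tensor operation.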