I’m attempting to fine-tune a Llama-2-7b-hf model to predict the quantities (in g or ml) of ingredients in a recipe.
For example:
Inputs to model:
Name of dish or drink: harissa chicken traybake
List of ingredients:
[chicken legs, harissa paste, garlic, lemon, cherry tomatoes, new potatoes, Kalamata olives, olive oil]
Ground truth output:
Ingredients and quantities:
[chicken legs, 200, harissa paste, 15, garlic, 15, lemon, 17.5, cherry tomatoes, 100, new potatoes, 87.5, Kalamata olives, 12.5, olive oil, 7.5]
Fine-tuning with the default SFTTrainer parameters produces output that looks like this:
[chicken legs, 100, harissa paste, 10, garlic, 10, lemon, 10, cherry tomatoes, 100, new potatoes, 100, Kalamata olives, 10, olive oil, 10]
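To put a number on the gap between the two lists above, here is a minimal sketch of the error I would like to minimise. The helper names `parse_quantities` and `mean_squared_error` are made up for this illustration (they are not part of any library), and the parsing assumes ingredient names never contain commas.

```python
def parse_quantities(text):
    """Split '[name, qty, name, qty, ...]' into (names, quantities).

    Hypothetical helper: assumes names contain no commas and that
    names and quantities strictly alternate.
    """
    items = [item.strip() for item in text.strip().strip("[]").split(",")]
    names = items[0::2]
    quantities = [float(q) for q in items[1::2]]
    return names, quantities


def mean_squared_error(predicted, actual):
    """Plain-Python MSE over two equal-length lists of numbers."""
    if len(predicted) != len(actual):
        raise ValueError("lists must have the same length")
    return sum((p - a) ** 2 for p, a in zip(predicted, actual)) / len(actual)


ground_truth = (
    "[chicken legs, 200, harissa paste, 15, garlic, 15, lemon, 17.5, "
    "cherry tomatoes, 100, new potatoes, 87.5, Kalamata olives, 12.5, olive oil, 7.5]"
)
prediction = (
    "[chicken legs, 100, harissa paste, 10, garlic, 10, lemon, 10, "
    "cherry tomatoes, 100, new potatoes, 100, Kalamata olives, 10, olive oil, 10]"
)

_, actual = parse_quantities(ground_truth)
_, predicted = parse_quantities(prediction)
print(mean_squared_error(predicted, actual))  # 1284.375
```

Most of that error comes from the chicken legs (200 vs 100), which is the kind of mistake the default token-level loss does not seem to penalise in proportion to its size.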
I believe I need a custom loss function that minimises the mean squared error between the predicted and ground-truth quantity values. I have subclassed SFTTrainer as below, but I get a gradient of zero. Is it possible to create a custom loss function in this way? I’m fairly new to this, so any help would be very much appreciated!
import torch
from torch import nn
from trl import SFTTrainer

# `tokenizer` and `quantities` are defined elsewhere in my script.

class CustomSFTTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        # get label and prediction tokens
        labels = inputs.get("labels")
        outputs = model(**inputs)
        predictions = outputs.get("logits")

        # decode predictions and labels
        predicted_token_ids = torch.argmax(predictions, dim=-1)
        decoded_predictions = [tokenizer.decode(p.tolist()) for p in predicted_token_ids]
        decoded_labels = [tokenizer.decode(l.tolist()) for l in labels]

        # function to output quantities to a list
        predicted_quantities, actual_quantities = quantities(decoded_predictions, decoded_labels)
        predicted_tensor = torch.tensor(predicted_quantities, device=model.device)
        actual_tensor = torch.tensor(actual_quantities, device=model.device)
        predicted_tensor.requires_grad_()

        # Compute MSE loss
        loss_function = nn.MSELoss()
        loss = loss_function(predicted_tensor, actual_tensor)

        return (loss, outputs) if return_outputs else loss
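To narrow this down, here is a minimal sketch with toy tensors (not my actual model or tokenizer) that appears to reproduce the symptom. The random `logits` tensor stands in for the model output, and converting the token IDs to floats stands in for the decode-and-parse step; both are assumptions made purely for illustration. My understanding is that `torch.argmax` returns integer indices with no gradient history, and that `torch.tensor(...)` builds a brand-new leaf tensor, so the loss never connects back to the logits.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for the model's logits: (batch, sequence, vocabulary).
logits = torch.randn(1, 4, 10, requires_grad=True)

# argmax yields integer token IDs, which carry no gradient history.
token_ids = torch.argmax(logits, dim=-1)
print(token_ids.requires_grad)  # False

# Stand-in for decode + parse: plain Python floats rebuilt into a new tensor.
values = torch.tensor([float(t) for t in token_ids.flatten().tolist()])
values.requires_grad_()  # makes `values` a fresh leaf, unrelated to `logits`

target = torch.zeros_like(values)
loss = F.mse_loss(values, target)
loss.backward()

# The gradient stops at `values`; nothing flows back to `logits`.
print(values.grad is not None)  # True
print(logits.grad)              # None
```

If that reading is right, calling `requires_grad_()` on the rebuilt tensor only makes the new tensor trainable and leaves the model parameters without any gradient, which would match what I am seeing.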