Accessing labels in the compute_metrics function

Hi everyone,

I’m fine-tuning amd/AMD-Llama-135m, a small Llama 2-based model. I have a dataset of short sequences split into “input” and “output” fields. Here’s how I tokenize my dataset:

def tokenize(examples):
    input_texts = examples["input"]
    output_texts = examples["output"]

    # Close each prompt with two </s> tokens and wrap each target in <s> ... </s>.
    modified_input_texts = []
    modified_output_texts = []
    for input_text, output_text in zip(input_texts, output_texts):
        modified_input_texts.append(f"{input_text.strip()}</s></s>")
        modified_output_texts.append(f"<s>{output_text.strip()}</s>")

    tokenized = tokenizer(modified_input_texts, max_length=64, padding="max_length", truncation=True)
    labels = tokenizer(modified_output_texts, max_length=64, padding="max_length", truncation=True)
    tokenized["labels"] = labels["input_ids"]
    return tokenized

After tokenizing my dataset, I can decode the labels as expected:

tokenizer.decode(tokenized_dataset["test"][0]["labels"])
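For context, tokenized_dataset is the result of mapping tokenize over the raw dataset. A minimal sketch of that step, assuming a DatasetDict named dataset with the “input”/“output” columns described above:

# Sketch of the mapping step (the name `dataset` is assumed, not shown above).
tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=["input", "output"])

# Side-by-side check: input_ids should decode to the input text,
# labels should decode to the output text.
sample = tokenized_dataset["test"][0]
print(tokenizer.decode(sample["input_ids"]))
print(tokenizer.decode(sample["labels"]))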

I’m launching a training run with a custom compute_metrics function, and before it a preprocess_logits_for_metrics function. Here’s what they look like:

import re

import torch
from evaluate import load

exact_match = load("exact_match")

def compute_metrics(eval_pred):
    labels = eval_pred.label_ids
    # predictions is the tuple returned by preprocess_logits_for_metrics,
    # so pred_ids[0] holds the argmaxed token ids.
    pred_ids = eval_pred.predictions
    labels[labels == -100] = tokenizer.pad_token_id
    pred_ids[0][pred_ids[0] == -100] = tokenizer.pad_token_id

    decoded_preds = tokenizer.batch_decode(pred_ids[0], skip_special_tokens=False)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=False)

    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [label.strip() for label in decoded_labels]

    def extract_between_tokens(text):
        # Pull out the span between <s><s> and the first </s>;
        # fall back to the full text when the pattern is absent.
        match = re.search(r'<s><s>(.*?)</s>', text)
        if match:
            return match.group(1)
        return text

    preds_between_tokens = [extract_between_tokens(pred) for pred in decoded_preds]

    # Print the mismatches for inspection.
    wrong_predictions = [
        (pred, label) for pred, label in zip(preds_between_tokens, decoded_labels) if pred != label
    ]
    for pred, label in wrong_predictions:
        print(f"Prediction: {pred}")
        print(f"Ground Truth: {label}")
        print("-------")

    results = exact_match.compute(predictions=preds_between_tokens, references=decoded_labels)
    
    return {"exact_match": results["exact_match"]}


def preprocess_logits_for_metrics(logits, labels):
    # Debug only: decode the labels as they arrive here.
    labels[labels == -100] = tokenizer.pad_token_id
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=False)
    print(f"Dec labels preprocess: {decoded_labels}")

    # Reduce the logits to token ids so compute_metrics receives (pred_ids, labels).
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids, labels
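
For completeness, here is roughly how everything is wired into the Trainer. This is a sketch: the hyperparameters and output_dir are placeholders, not my exact configuration.

from transformers import Trainer, TrainingArguments

# Sketch of the training setup; values here are placeholders.
training_args = TrainingArguments(
    output_dir="out",
    eval_strategy="epoch",  # "evaluation_strategy" on older transformers versions
    per_device_eval_batch_size=8,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
trainer.train()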

When I print the decoded labels in preprocess_logits_for_metrics (purely for debugging; it runs before compute_metrics), I see the inputs rather than the labels. The same happens in compute_metrics.
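
One way to narrow this down is to look at the first batch the Trainer’s eval dataloader actually produces, since that is exactly what these two functions receive. A debugging sketch using the Trainer’s get_eval_dataloader method:

# Inspect what the eval dataloader feeds the model; the "labels" tensor here
# is what preprocess_logits_for_metrics and compute_metrics will see.
batch = next(iter(trainer.get_eval_dataloader()))
print(batch.keys())

first_labels = batch["labels"][0].clone()
first_labels[first_labels == -100] = tokenizer.pad_token_id
print(tokenizer.decode(first_labels))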

I searched a lot for this particular issue but could not find any relevant topics. Can anyone share any ideas on what might be causing this?

Thanks in advance!