Hi everyone,
I’m fine-tuning amd/AMD-Llama-135m, a small Llama 2-based model, on a dataset of short sequences split into “input” and “output” fields. Here’s how I tokenize my dataset:
```python
def tokenize(examples):
    input_texts = examples["input"]
    output_texts = examples["output"]
    modified_input_texts = []
    modified_output_texts = []
    for input_text, output_text in zip(input_texts, output_texts):
        # Mark the end of the prompt, and wrap the target in BOS/EOS
        modified_input_texts.append(f"{input_text.strip()}</s></s>")
        modified_output_texts.append(f"<s>{output_text.strip()}</s>")
    tokenized = tokenizer(modified_input_texts, max_length=64, padding="max_length", truncation=True)
    labels = tokenizer(modified_output_texts, max_length=64, padding="max_length", truncation=True)
    # Use the tokenized output texts as the labels
    tokenized["labels"] = labels["input_ids"]
    return tokenized
```
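I apply this over the dataset in batched mode, along these lines (the `dataset` name here is illustrative):

```python
# Batched map, so tokenize() receives lists of inputs/outputs at once
tokenized_dataset = dataset.map(tokenize, batched=True)
```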
After tokenizing my dataset, I can decode the labels as expected:
```python
tokenizer.decode(tokenized_dataset["test"][0]["labels"])
```
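(For context, the tokenizer is the model’s stock tokenizer, loaded roughly like this — the pad-token fallback is my setup, since Llama-family tokenizers usually ship without one and `padding="max_length"` needs it:)

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("amd/AMD-Llama-135m")
if tokenizer.pad_token is None:
    # Llama-family tokenizers typically define no pad token; reuse EOS
    tokenizer.pad_token = tokenizer.eos_token
```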
I’m launching a training run with a custom compute_metrics function and, before it, a preprocess_logits_for_metrics function. Here’s what they look like:
```python
import re

import torch
from evaluate import load

exact_match = load("exact_match")

def compute_metrics(eval_pred):
    labels = eval_pred.label_ids
    pred_ids = eval_pred.predictions
    # Restore pad tokens wherever the Trainer inserted -100
    labels[labels == -100] = tokenizer.pad_token_id
    pred_ids[0][pred_ids[0] == -100] = tokenizer.pad_token_id
    decoded_preds = tokenizer.batch_decode(pred_ids[0], skip_special_tokens=False)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=False)
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [label.strip() for label in decoded_labels]

    def extract_between_tokens(text):
        # Pull out the span between the leading <s><s> and the first </s>
        match = re.search(r"<s><s>(.*?)</s>", text)
        if match:
            return match.group(1)
        return text

    preds_between_tokens = [extract_between_tokens(pred) for pred in decoded_preds]
    wrong_predictions = [
        (pred, label)
        for pred, label in zip(preds_between_tokens, decoded_labels)
        if pred != label
    ]
    for pred, label in wrong_predictions:
        print(f"Prediction: {pred}")
        print(f"Ground Truth: {label}")
        print("-------")
    results = exact_match.compute(predictions=preds_between_tokens, references=decoded_labels)
    return {"exact_match": results["exact_match"]}

def preprocess_logits_for_metrics(logits, labels):
    # Debug only: decode the labels the Trainer hands in at this stage
    labels[labels == -100] = tokenizer.pad_token_id
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=False)
    print(f"Dec labels preprocess: {decoded_labels}")
    # Greedy token ids from the logits
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids, labels
```
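For completeness, both hooks are passed to the Trainer roughly like this (model and training args abbreviated, not my exact config):

```python
from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out"),  # real args trimmed
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
trainer.train()
```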
When I print the decoded labels in preprocess_logits_for_metrics (purely for debugging), which runs before compute_metrics, I see the inputs rather than the labels. The same thing happens in compute_metrics.
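To make the symptom concrete, this is a minimal sketch of the check, reusing the names above:

```python
sample = tokenized_dataset["test"][0]
print(tokenizer.decode(sample["input_ids"]))  # input text, as expected
print(tokenizer.decode(sample["labels"]))     # output text, as expected
# ...yet the debug print inside preprocess_logits_for_metrics shows the
# input text again, not the output text
```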
I searched a lot for this particular issue but could not find any relevant topics. Can anyone share ideas on what might be causing this?
Thanks in advance!