I’m using a T5ForConditionalGeneration model for natural language inference. When I build the prompt by hand for a few examples (using the "mnli: … " prefix) to test its zero-shot capabilities on a new dataset, everything works fine.
But when I call the evaluate method on the Trainer, the model, instead of generating “contradiction” (for example), returns "contradiction contradiction … ", repeating the word until it reaches the maximum sequence length.
This is the function I use for mapping the dataset:
def tokenize_function(example, tokenizer):
    prompts = generate_batch_prompts_mnli(example)
    l = ["entailment", "neutral", "contradiction"]
    # Tokenize the premise (input) and label
    inputs = tokenizer(prompts, padding='max_length', truncation=True, max_length=128)
    labels = tokenizer([l[i] for i in example["label"]], padding="max_length", truncation=True)
    # Return a dictionary containing input and label tokens
    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": labels["input_ids"],
    }
This is the compute_metrics function I pass to the Trainer, together with its helper functions:
def compute_metrics(eval_pred, transform, metric):
    """Compute the metrics.

    Args:
        eval_pred (EvalPrediction): the predictions and labels.
        transform (function): the function to transform the logits and labels.
        metric (datasets.Metric): the metric.

    Returns:
        dict: the computed metrics.
    """
    pred, labels = transform(eval_pred)
    return metric.compute(predictions=pred, references=labels)

def eval_pred_transform_accuracy(eval_pred, tokenizer):
    """Transform the logits and labels to compute the accuracy.

    Args:
        eval_pred (EvalPrediction): the predictions and labels.
        tokenizer (transformers.PreTrainedTokenizer): the tokenizer.

    Returns:
        tuple: predictions and labels.
    """
    pred_ids = eval_pred.predictions[0]
    labels = eval_pred.label_ids
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return pred_str, label_str

def preprocess_logits_argmax(logits, labels):
    """Pre-process the logits and labels to compute the metrics.

    Args:
        logits (list of torch.Tensor): the logits and the labels logits.
        labels (torch.Tensor): the labels.

    Returns:
        tuple: predictions and labels.
    """
    pred_ids = logits[0].argmax(dim=-1)
    return pred_ids, labels
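One way to check whether the repeated words come from positions after the end of the label is to cut each predicted sequence at its first EOS id before decoding. With a plain Trainer, the logits come from a forward pass over every label position, padding included, so the argmax at padded positions may not be meaningful. A minimal sketch, assuming T5's default ids (EOS = 1, pad = 0); `truncate_at_eos` is a hypothetical helper:

```python
def truncate_at_eos(pred_ids, eos_token_id=1, pad_token_id=0):
    """Keep tokens up to the first EOS and overwrite the rest with padding.

    eos_token_id=1 and pad_token_id=0 are assumptions (T5 defaults);
    in practice read them from the tokenizer.
    """
    cleaned = []
    for seq in pred_ids:
        out = []
        seen_eos = False
        for tok in seq:
            if seen_eos:
                out.append(pad_token_id)  # ignore everything after EOS
            else:
                out.append(tok)
                seen_eos = tok == eos_token_id
        cleaned.append(out)
    return cleaned


# Toy ids: 42 stands in for the "contradiction" token.
raw = [[42, 1, 42, 42, 42], [42, 42, 42, 42, 42]]
cleaned = truncate_at_eos(raw)
print(cleaned)
```

If the decoded strings look right after this step, the repetition is probably an artifact of decoding the padded positions rather than of the model's generation. Using Seq2SeqTrainer with predict_with_generate=True in Seq2SeqTrainingArguments is the usual way to evaluate with real generation instead.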
The code can be found at this GitHub repo.
What is going on?