T5 Model Evaluation on Generation

I’m using a T5ForConditionalGeneration model for natural language inference. When I build the prompt by hand for a few examples to test its zero-shot capabilities on a new dataset (prefixing the input with "mnli: … "), everything works fine.
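For context, the prompt builder looks roughly like this (a sketch, not the exact code from the repo; the column names are the standard MNLI ones and the exact prefix/template is the part elided above):

def generate_batch_prompts_mnli(example):
    # Build one "mnli: ..." prompt per premise/hypothesis pair in the batch.
    return [
        f"mnli: premise: {p} hypothesis: {h}"
        for p, h in zip(example["premise"], example["hypothesis"])
    ]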

Then I try the Trainer's evaluate method, and instead of generating “contradiction” (for example), the model returns "contradiction contradiction … " over and over until it hits the maximum sequence length.

This is the function I use for mapping the dataset:

def tokenize_function(example, tokenizer):
    prompts = generate_batch_prompts_mnli(example)
    l = ["entailment", "neutral", "contradiction"]

    # Tokenize the prompts (model inputs) and the label words (targets)
    inputs = tokenizer(prompts, padding='max_length', truncation=True, max_length=128)
    labels = tokenizer([l[i] for i in example["label"]], padding="max_length", truncation=True)

    # Return a dictionary containing input and label tokens
    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": labels["input_ids"],
    }
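For completeness, this is roughly how I apply it with datasets.map (assuming the dataset is a datasets.DatasetDict and the tokenizer has already been loaded, e.g. with T5Tokenizer.from_pretrained):

from functools import partial

# map passes each batch dict to tokenize_function; the tokenizer is bound via partial
tokenized_dataset = dataset.map(
    partial(tokenize_function, tokenizer=tokenizer),
    batched=True,
)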

The compute_metrics function I pass to the Trainer is this:

def compute_metrics(eval_pred, transform, metric):
    """Compute the metrics.

    Args:
        eval_pred (EvalPrediction): the predictions and labels.
        transform (function): the function to transform the logits and labels.
        metric (datasets.Metric): the metric.

    Returns:
        dict: the computed metrics.

    """
    pred, labels = transform(eval_pred) 
    return metric.compute(predictions=pred, references=labels)
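The transform and metric are bound with functools.partial so the Trainer can call compute_metrics with a single EvalPrediction argument. Roughly (the actual metric I load isn't shown here; exact_match is one example of a metric whose compute() accepts the decoded strings):

import evaluate
from functools import partial

metric = evaluate.load("exact_match")  # placeholder: any metric taking decoded strings
compute_metrics_fn = partial(
    compute_metrics,
    transform=partial(eval_pred_transform_accuracy, tokenizer=tokenizer),
    metric=metric,
)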

def eval_pred_transform_accuracy(eval_pred, tokenizer):
    """Transform the logits and labels to compute the accuracy.

    Args:
        eval_pred (EvalPrediction): the predictions and labels.
        tokenizer (transformers.PreTrainedTokenizer): the tokenizer.

    Returns:
        tuple: predictions and labels.

    """
    pred_ids = eval_pred.predictions[0]
    labels = eval_pred.label_ids
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return pred_str, label_str

def preprocess_logits_argmax(logits, labels):
    """Pre-process the logits and labels to compute the metrics.

    Args:
        logits (tuple of torch.Tensor): the model outputs passed by the Trainer; the first element is the language-model logits.
        labels (torch.Tensor): the labels.

    Returns:
        tuple: predictions and labels.

    """
    pred_ids = logits[0].argmax(dim=-1)
    
    return pred_ids, labels
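And this is how the pieces are wired into the Trainer (a sketch; the split name and the TrainingArguments values are placeholders, the real ones are in the repo):

from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", per_device_eval_batch_size=16),
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics_fn,
    # reduce the logits to token ids before they are gathered for compute_metrics
    preprocess_logits_for_metrics=preprocess_logits_argmax,
)
metrics = trainer.evaluate()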

The code can be found at this GitHub repo.
What is going on?