Problem with custom metric for custom T5 model

I have created a custom dataset and trained a custom T5ForConditionalGeneration model on it. The model predicts solutions to quadratic equations, like this:

Input: "4*x^2 + 4*x + 1"
Output: D = 4 ^ 2 - 4 * 4 * 1 4 * 1 4 * 1 4 * 1 4 * 1 4
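
For reference, this is roughly how I query the trained model at inference time (a sketch only; the checkpoint path and variable names are placeholders, not my exact script):

from transformers import AutoTokenizer, T5ForConditionalGeneration

model_dir = 'path/to/my-checkpoint'   # placeholder for my trained checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = T5ForConditionalGeneration.from_pretrained(model_dir)

inputs = tokenizer('4*x^2 + 4*x + 1', return_tensors='pt')
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))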

I need accuracy for this model, but Trainer only reports the loss, so I used a custom metric function (I didn't write it myself; I took it from a similar project):

import numpy as np
import nltk
import wandb

def compute_metrics4token(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    num_correct, num_total = 0, 0   # matching tokens / total label tokens
    num_answer = 0                  # predictions that match the label exactly
    number_eq = 0                   # number of equations (examples)
    for p, l in zip(decoded_preds, decoded_labels):
        text_pred = p.split(' ')
        text_labels = l.split(' ')
        if np.array_equal(text_pred, text_labels):
            num_answer += 1
        for i, j in zip(text_pred, text_labels):
            if i == j:
                num_correct += 1
        num_total += len(text_labels)
        number_eq += 1
    token_accuracy = num_correct / num_total
    answer_accuracy = num_answer / number_eq
    result = {'token_acc': token_accuracy, 'answer_acc': answer_accuracy}
    for key, value in result.items():
        wandb.log({key: value})
    return {k: round(v, 4) for k, v in result.items()}

The problem is that it doesn't work, and I don't really understand why or what I can do to get accuracy for my model.
This is my training setup and the error I get when the metric function is used:

args = Seq2SeqTrainingArguments(
    output_dir='./',
    num_train_epochs=10,
    overwrite_output_dir=True,
    evaluation_strategy='steps',
    learning_rate=1e-4,
    logging_steps=100,
    eval_steps=100,
    save_steps=100,
    load_best_model_at_end=True,
    push_to_hub=True,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics4token,
)

During evaluation the metric function fails with this traceback:

<ipython-input-48-ff7980f6dd66> in compute_metrics4token(eval_pred)
      4     # predictions = np.argmax(logits[0])
      5     # print(predictions)
----> 6     decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
      7     # Replace -100 in the labels as we can't decode them.
      8     labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in batch_decode(self, sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
   3444             `List[str]`: The list of decoded sentences.
   3445         """
-> 3446         return [
   3447             self.decode(
   3448                 seq,

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in <listcomp>(.0)
   3445         """
   3446         return [
-> 3447             self.decode(
   3448                 seq,
   3449                 skip_special_tokens=skip_special_tokens,

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
   3484         token_ids = to_py_obj(token_ids)
   3485 
-> 3486         return self._decode(
   3487             token_ids=token_ids,
   3488             skip_special_tokens=skip_special_tokens,

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_fast.py in _decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
    547         if isinstance(token_ids, int):
    548             token_ids = [token_ids]
--> 549         text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
    550 
    551         clean_up_tokenization_spaces = (

TypeError: argument 'ids': 'list' object cannot be interpreted as an integer

When I print out predictions I get a tuple:

(array([[[-32.777344, -34.593437, -36.065685, ..., -34.78577 ,
          -34.77546 , -34.061115],
         ...,
         [-56.687027, -24.391508, -34.12526 , ..., -35.30828 ,
          -42.204193, -38.88395 ]],

        ...,

        [[-30.320919, -33.430378, -34.84311 , ..., -37.259563,
          -32.59662 , -33.03912 ],
         ...,
         [-58.966076, -23.795956, -34.519627, ..., -35.901787,
          -40.261116, -37.612514]]], dtype=float32),
 array([[[-1.43104442e-03, -2.98473001e-01,  9.49775204e-02, ...,
          -1.77978892e-02,  1.79805323e-01,  1.33578405e-01],
         ...,
         [ 9.58326831e-02, -3.00361574e-01, -3.02627794e-02, ...,
           3.01265554e-03,  1.20107472e-01,  9.56629887e-02]],

        ...,

        [[-1.22184128e-01, -3.67584258e-01,  3.60302597e-01, ...,
          -4.39502299e-02,  1.33717149e-01,  1.53699834e-02],
         ...,
         [ 4.50411513e-02, -3.61773074e-01, -5.50217964e-02, ...,
           3.68298292e-02,  1.67936400e-01,  1.76781893e-01]]],
       dtype=float32))

I thought that maybe I need to take the argmax of these values, but then I still get errors.
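
In case it helps, this is roughly the argmax version I tried (a sketch, assuming the first element of the tuple holds the logits and that `tokenizer` is in scope); maybe I'm doing this part wrong?

import numpy as np

def compute_metrics_from_logits(eval_pred):
    predictions, labels = eval_pred
    # Without predict_with_generate, `predictions` can be a tuple whose first
    # element are the logits with shape (batch, seq_len, vocab_size).
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    pred_ids = np.argmax(predictions, axis=-1)  # token ids, shape (batch, seq_len)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Same accuracy logic as above, just on the decoded strings.
    correct = total = exact = 0
    for p, l in zip(decoded_preds, decoded_labels):
        p_tok, l_tok = p.split(), l.split()
        exact += int(p_tok == l_tok)
        correct += sum(i == j for i, j in zip(p_tok, l_tok))
        total += len(l_tok)
    return {"token_acc": round(correct / total, 4),
            "answer_acc": round(exact / len(decoded_labels), 4)}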

If something is unclear I would be happy to provide additional information. Thanks for any help.

Hi @snork-maiden, were you able to resolve this? I am facing the same challenge.
I referred to Not sure how to compute BLEU through compute_metrics - #5 by xiami, used Seq2SeqTrainingArguments with predict_with_generate=True, and still got TypeError: argument 'ids': 'list' object cannot be interpreted as an integer.

The predictions should have 2 dimensions, but we got 3 instead …
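
For context, this is roughly how I wire it up (a sketch with placeholder names, not my exact script):

from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

args = Seq2SeqTrainingArguments(
    output_dir='./out',                 # placeholder path
    evaluation_strategy='steps',
    eval_steps=100,
    per_device_eval_batch_size=4,
    predict_with_generate=True,         # evaluation should then pass generated token ids
)

trainer = Seq2SeqTrainer(
    model=model,                        # placeholders: my own model, datasets, tokenizer
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics4token,
)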