I have created a custom dataset and trained on it a custom T5ForConditionalGeneration
model that predicts solutions to quadratic equations like this:
Input: "4*x^2 + 4*x + 1"
Output: D = 4 ^ 2 - 4 * 4 * 1 4 * 1 4 * 1 4 * 1 4 * 1 4
I need to get accuracy for this model but I get only loss when I use Trainer
so I used a custom metric function (I didn’t write it but took it from a similar project):
def compute_metrics4token(eval_pred):
batch_size = 4
predictions, labels = eval_pred
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
# Replace -100 in the labels as we can't decode them.
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# Rouge expects a newline after each sentence
decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
answer_accuracy = []
token_accuracy = []
num_correct, num_total = 0, 0
num_answer = 0
number_eq = 0
for p, l in zip(decoded_preds, decoded_labels):
text_pred = p.split(' ')
text_labels = l.split(' ')
m = min(len(text_pred), len(text_labels))
if np.array_equal(text_pred, text_labels):
num_answer += 1
for i, j in zip(text_pred, text_labels):
if i == j:
num_correct += 1
num_total += len(text_labels)
number_eq += 1
token_accuracy = num_correct / num_total
answer_accuracy = num_answer / number_eq
result = {'token_acc': token_accuracy, 'answer_acc': answer_accuracy}
result = {key: value for key, value in result.items()}
for key, value in result.items():
wandb.log({key: value})
return {k: round(v, 4) for k, v in result.items()}
Problem is that it doesn’t work and I don’t really understand why and what can I do to get accuracy for my model.
I get this error when I use the function:
args = Seq2SeqTrainingArguments(
output_dir='./',
num_train_epochs=10,
overwrite_output_dir = True,
evaluation_strategy = 'steps',
learning_rate = 1e-4,
logging_steps = 100,
eval_steps = 100,
save_steps = 100,
load_best_model_at_end = True,
push_to_hub=True,
weight_decay = 0.01,
per_device_train_batch_size=8,
per_device_eval_batch_size=4
)
trainer = Seq2SeqTrainer(model=model, train_dataset=train_dataset, eval_dataset=eval_dataset, args=args,
data_collator=data_collator, tokenizer=tokenizer, compute_metrics=compute_metrics4token)
<ipython-input-48-ff7980f6dd66> in compute_metrics4token(eval_pred)
4 # predictions = np.argmax(logits[0])
5 # print(predictions)
----> 6 decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
7 # Replace -100 in the labels as we can't decode them.
8 labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in batch_decode(self, sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
3444 `List[str]`: The list of decoded sentences.
3445 """
-> 3446 return [
3447 self.decode(
3448 seq,
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in <listcomp>(.0)
3445 """
3446 return [
-> 3447 self.decode(
3448 seq,
3449 skip_special_tokens=skip_special_tokens,
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
3484 token_ids = to_py_obj(token_ids)
3485
-> 3486 return self._decode(
3487 token_ids=token_ids,
3488 skip_special_tokens=skip_special_tokens,
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_fast.py in _decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
547 if isinstance(token_ids, int):
548 token_ids = [token_ids]
--> 549 text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
550
551 clean_up_tokenization_spaces = (
TypeError: argument 'ids': 'list' object cannot be interpreted as an integer
When I print out predictions
I get a tuple:
(array([[[-32.777344, -34.593437, -36.065685, ..., -34.78577 ,
-34.77546 , -34.061115],
[-58.633934, -32.23472 , -31.735909, ..., -40.335655,
-40.28701 , -37.208904],
[-56.650974, -33.564095, -34.409576, ..., -36.94467 ,
-43.246735, -37.469246],
...,
[-56.62741 , -24.561722, -34.11228 , ..., -35.34798 ,
-42.287125, -38.889412],
[-56.632545, -24.470266, -34.0792 , ..., -35.313175,
-42.235626, -38.891712],
[-56.687027, -24.391508, -34.12526 , ..., -35.30828 ,
-42.204193, -38.88395 ]],
[[-29.79866 , -32.22621 , -32.689865, ..., -32.106445,
-31.46681 , -31.706667],
[-62.101192, -33.327423, -30.900173, ..., -38.046883,
-42.26345 , -38.97748 ],
[-54.726807, -29.13115 , -30.294558, ..., -28.370876,
-41.23722 , -37.91609 ],
...,
[-57.279373, -23.954525, -34.066246, ..., -35.047447,
-41.599922, -38.489853],
[-57.31298 , -23.879845, -34.0837 , ..., -35.03614 ,
-41.557755, -38.530064],
[-57.39132 , -23.831306, -34.120094, ..., -35.039547,
-41.525337, -38.55728 ]],
[[-29.858566, -32.452713, -34.05892 , ..., -33.93065 ,
-32.109177, -32.874695],
[-61.375793, -33.656853, -32.95248 , ..., -42.28087 ,
-42.637173, -39.21142 ],
[-58.43721 , -32.496166, -36.44046 , ..., -39.33864 ,
-42.139664, -38.695328],
...,
[-59.654663, -24.117435, -34.266438, ..., -35.734142,
-40.55384 , -38.467537],
[-38.54418 , -18.533113, -29.775307, ..., -26.856483,
-33.07976 , -29.934727],
[-27.716005, -14.610603, -23.752686, ..., -21.140053,
-26.855148, -24.429493]],
...,
[[-33.252697, -34.72487 , -36.395184, ..., -36.87368 ,
-35.207897, -34.468285],
[-59.911736, -32.730076, -32.622803, ..., -43.382267,
-42.25615 , -38.35135 ],
[-54.982887, -31.847572, -32.773827, ..., -38.500675,
-43.97969 , -37.41088 ],
...,
[-56.896988, -23.213766, -34.04734 , ..., -35.88832 ,
-42.176086, -38.953568],
[-56.994152, -23.141619, -34.054848, ..., -35.875816,
-42.176453, -38.97729 ],
[-57.076714, -23.05831 , -34.048904, ..., -35.888298,
-42.165287, -39.020435]],
[[-30.070187, -32.049232, -34.63928 , ..., -35.02118 ,
-32.14465 , -32.891876],
[-61.720093, -32.994057, -32.988144, ..., -42.054638,
-42.18583 , -38.990112],
[-57.74364 , -31.431454, -35.969643, ..., -38.593002,
-42.276768, -38.895355],
...,
[-58.677704, -23.567434, -35.6751 , ..., -36.018696,
-40.343582, -38.681267],
[-58.682228, -23.563087, -35.668964, ..., -36.019753,
-40.336178, -38.67661 ],
[-58.718002, -23.609531, -35.67758 , ..., -36.001644,
-40.366055, -38.67864 ]],
[[-30.320919, -33.430378, -34.84311 , ..., -37.259563,
-32.59662 , -33.03912 ],
[-61.275875, -34.824192, -34.07767 , ..., -44.637024,
-41.718002, -38.974827],
[-54.49349 , -30.689342, -35.539658, ..., -39.984665,
-39.87059 , -37.038437],
...,
[-58.939384, -23.831846, -34.525368, ..., -35.930893,
-40.29633 , -37.637936],
[-58.95117 , -23.824234, -34.520042, ..., -35.931396,
-40.297188, -37.636852],
[-58.966076, -23.795956, -34.519627, ..., -35.901787,
-40.261116, -37.612514]]], dtype=float32), array([[[-1.43104442e-03, -2.98473001e-01, 9.49775204e-02, ...,
-1.77978892e-02, 1.79805323e-01, 1.33578405e-01],
[-2.35560730e-01, 1.53045550e-01, 5.15255742e-02, ...,
-1.57466665e-01, 3.49459350e-01, 7.28092641e-02],
[ 1.60562042e-02, -1.40354022e-01, 5.29232398e-02, ...,
-2.38162443e-01, -7.72500336e-02, 6.80136457e-02],
...,
[ 7.33550191e-02, -3.35853845e-01, 2.25579832e-03, ...,
-1.93636306e-02, 1.08121082e-01, 5.24416938e-02],
[ 8.32231194e-02, -3.11688155e-01, -2.13681534e-02, ...,
3.23344418e-03, 1.08062990e-01, 7.20862746e-02],
[ 9.58326831e-02, -3.00361574e-01, -3.02627794e-02, ...,
3.01265554e-03, 1.20107472e-01, 9.56629887e-02]],
[[-1.16950013e-01, -3.43173921e-01, 1.87818244e-01, ...,
-2.71256089e-01, 7.42092952e-02, 5.77520356e-02],
[-1.62564963e-01, -3.87467295e-01, 1.71134964e-01, ...,
-7.83916116e-02, -3.65173034e-02, 2.08234787e-01],
[-3.71523261e-01, -8.74521434e-02, 1.39187068e-01, ...,
-3.08779895e-01, 3.88156146e-01, 9.99216512e-02],
...,
[ 2.14628279e-02, -3.35561454e-01, -3.76663893e-03, ...,
-1.29795140e-02, 1.44181430e-01, 1.15508482e-01],
[ 3.47745977e-02, -3.30934107e-01, 1.10013550e-02, ...,
-1.84394475e-02, 1.52143195e-01, 1.38157398e-01],
[ 3.02720107e-02, -3.37626845e-01, 1.35379741e-02, ...,
-3.80427912e-02, 1.50906458e-01, 1.38765752e-01]],
[[-6.50129542e-02, -2.63762653e-01, 2.16862872e-01, ...,
-1.66922837e-01, 1.09285273e-01, -6.40013069e-02],
[-5.23199737e-01, -2.32228413e-01, 1.44963071e-01, ...,
-1.41557693e-01, 1.90811172e-01, -2.22496167e-01],
[-2.24985227e-01, -3.69372189e-01, 7.32450858e-02, ...,
6.57786876e-02, 9.70033705e-02, 7.83021152e-02],
...,
[-1.93579309e-03, -3.92921537e-01, -1.28203649e-02, ...,
-8.74079913e-02, 1.13596492e-01, 9.25250202e-02],
[ 4.55581211e-03, -3.65802884e-01, -2.60831695e-02, ...,
-4.12549600e-02, 1.17429778e-01, 1.05997331e-01],
[ 2.46201381e-02, -3.47863257e-01, -4.48134281e-02, ...,
-2.53352951e-02, 1.16753690e-01, 1.36296600e-01]],
...,
[[-6.47678748e-02, -3.45555365e-01, 7.19114989e-02, ...,
-9.16809738e-02, 2.15520635e-01, 1.01671875e-01],
[-7.61077851e-02, -1.51827012e-03, 9.52102616e-02, ...,
-1.39335945e-01, 1.05894208e-01, 3.23191588e-03],
[-3.24888170e-01, -2.17741728e-03, 5.32661797e-03, ...,
-2.78430730e-01, 3.59415114e-01, 1.19439401e-01],
...,
[ 6.89201057e-02, -3.63149673e-01, 7.96841756e-02, ...,
-3.25191446e-04, 1.26513481e-01, 1.36511743e-01],
[ 8.16355348e-02, -3.54205281e-01, 7.69739375e-02, ...,
-2.90949806e-03, 1.31863236e-01, 1.56503588e-01],
[ 8.36645439e-02, -3.38536322e-01, 8.00612345e-02, ...,
-9.39210225e-03, 1.29102767e-01, 1.64855778e-01]],
[[-1.63163885e-01, -3.34902078e-01, 1.11728966e-01, ...,
-1.10363133e-01, 1.19786285e-01, -9.18702483e-02],
[-3.36889774e-01, -3.34888607e-01, 1.30680993e-01, ...,
1.22191897e-03, 1.45059675e-01, -1.27688542e-01],
[-5.92090450e-02, -2.07585752e-01, 2.05589265e-01, ...,
-6.80094585e-02, 2.11224273e-01, 3.92790437e-01],
...,
[ 4.86238785e-02, -4.19503808e-01, -3.39424387e-02, ...,
-1.76134892e-02, 1.00283481e-01, 1.38210282e-01],
[ 5.81516996e-02, -4.04477298e-01, -4.19086292e-02, ...,
-1.02474755e-02, 1.06062084e-01, 1.59754634e-01],
[ 6.70261905e-02, -3.86263877e-01, -4.19785343e-02, ...,
9.05385148e-03, 1.01594023e-01, 1.69663757e-01]],
[[-1.22184128e-01, -3.67584258e-01, 3.60302597e-01, ...,
-4.39502299e-02, 1.33717149e-01, 1.53699834e-02],
[-3.37780178e-01, -4.05100137e-01, 2.02614054e-01, ...,
-5.41410968e-02, 1.55447468e-01, -9.28792357e-02],
[ 1.81227952e-01, -2.29236633e-01, 2.40814224e-01, ...,
1.39913429e-02, 7.61386827e-02, 3.62152725e-01],
...,
[ 1.47830993e-02, -4.26465064e-01, -1.54972840e-02, ...,
3.74358669e-02, 1.52016997e-01, 1.53155088e-01],
[ 3.46656404e-02, -4.00052220e-01, -3.53843644e-02, ...,
2.64652576e-02, 1.62517026e-01, 1.66649833e-01],
[ 4.50411513e-02, -3.61773074e-01, -5.50217964e-02, ...,
3.68298292e-02, 1.67936400e-01, 1.76781893e-01]]],
dtype=float32))
I thought that maybe I need to take argmax from these values but then I still get errors.
If something is unclear I would be happy to provide additional information. Thanks for any help.