Edit: this was run on a GPU.
I stumbled upon a quirk while trying to figure out how to compute custom metrics, using a setup similar to the Object detection example.
I am using a DETR model for object detection. When I use a custom metric, the expanded label fields are merged in a way that makes it impossible to use them or to associate them with the original inputs:
def compute_metrics(eval_pred):
    """Debugging metric hook: dump the gathered labels and return a placeholder.

    ``eval_pred`` must unpack into ``(predictions, labels)``, where
    ``predictions`` itself unpacks into exactly five DETR outputs. A
    ``ValueError`` is raised if either unpacking does not match.
    """
    predictions, batched_labels = eval_pred
    # Unpacked only to pin the expected layout of the model outputs; unused below.
    loss_dict, logits, pred_boxes, last_hidden_state, encoder_last_hidden_state = predictions
    print(batched_labels)
    return {"dummy": 1}
from transformers import Trainer

# Wire the custom compute_metrics hook into a Trainer so that evaluation
# hands it the model predictions together with the labels.
# NOTE(review): the train split is also used as eval_dataset here, presumably
# only for this repro.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=cppe5["train"],
    eval_dataset=cppe5["train"],
    data_collator=collate_fn,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

# Evaluate on only the first 16 training examples so the printed labels stay short.
eval_subset = cppe5["train"].select(range(16))
trainer.evaluate(eval_dataset=eval_subset)
[{'size': array([800, 800, 800, 800]), 'image_id': array([ 15, 285]), 'class_labels': array([4, 4, 0, 0, 0]), 'boxes': array([[0.64103925, 0.2036199 , 0.07741249, 0.07843136],
[0.11081654, 0.1719457 , 0.06044539, 0.04223229],
[0.6988335 , 0.5113122 , 0.26299044, 0.9291101 ],
[0.10710498, 0.40497738, 0.21420996, 0.60482657],
[0.568 , 0.53099996, 0.544 , 0.938 ]], dtype=float32), 'area': array([ 10544.444, 4433.333, 424355.53 , 225005.55 , 265766.66 ],
dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}, {'size': array([800, 800, 800, 800]), 'image_id': array([762, 834]), 'class_labels': array([4, 4, 4, 2, 2, 2, 2, 2, 0, 0, 3, 0, 4]), 'boxes': array([[0.8878005 , 0.3708278 , 0.06233307, 0.0847797 ],
[0.73219055, 0.35580772, 0.07791626, 0.06675567],
[0.50845945, 0.3538051 , 0.05520931, 0.06275036],
[0.7609083 , 0.57142854, 0.06589493, 0.06008007],
[0.8566341 , 0.6124833 , 0.03116646, 0.01668892],
[0.8210152 , 0.63818425, 0.05164741, 0.06408546],
[0.6622885 , 0.5257009 , 0.0503117 , 0.06074768],
[0.61130893, 0.5487316 , 0.06678536, 0.08144192],
[0.91184324, 0.5684245 , 0.17542297, 0.453271 ],
[0.5331701 , 0.4289052 , 0.12600182, 0.17022698],
[0.5347284 , 0.31041387, 0.08014248, 0.05473965],
[0.35507813, 0.63125 , 0.26953125, 0.7375 ],
[0.38125 , 0.46805555, 0.096875 , 0.09722222]], dtype=float32), 'area': array([ 49388.887, 48611.11 , 32377.777, 37000. , 4861.111,
30933.332, 28563.889, 50833.332, 743127.75 , 200458.33 ,
41000. , 508874.97 , 24111.11 ], dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}, {'size': array([800, 800, 800, 800]), 'image_id': array([656, 39]), 'class_labels': array([0, 4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0]), 'boxes': array([[0.49375 , 0.5694444 , 0.49375 , 0.86111116],
[0.8590625 , 0.25833333, 0.12312508, 0.11222223],
[0.6671875 , 0.245 , 0.091875 , 0.08333336],
[0.5065625 , 0.24222222, 0.08187496, 0.12888888],
[0.40375 , 0.22277777, 0.06125 , 0.11 ],
[0.24124998, 0.30555555, 0.0775 , 0.07555553],
[0.948125 , 0.78055555, 0.09625 , 0.13888885],
[0.801875 , 0.75722224, 0.0875 , 0.15 ],
[0.59625 , 0.74888885, 0.12875 , 0.18 ],
[0.4878125 , 0.6783333 , 0.091875 , 0.15666664],
[0.3890625 , 0.70555556, 0.076875 , 0.12 ],
[0.349375 , 0.745 , 0.06124996, 0.15666671],
[0.200625 , 0.8416667 , 0.0575 , 0.17000008],
[0.878125 , 0.5733333 , 0.2 , 0.85333335],
[0.6609375 , 0.5422222 , 0.228125 , 0.9155556 ],
[0.40875 , 0.5516667 , 0.18624996, 0.89666665],
[0.225625 , 0.55 , 0.24875 , 0.9 ]], dtype=float32), 'area': array([1088444.4 , 55269.44 , 30624.998, 42211.11 , 26949.998,
23422.22 , 53472.22 , 52499.996, 92700. , 57574.996,
36900. , 38383.332, 39100. , 682666.6 , 835444.44 ,
668016.6 , 895499.94 ], dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}, {'size': array([800, 800, 800, 800]), 'image_id': array([419, 49]), 'class_labels': array([4, 1, 1, 1, 4, 4, 2, 2, 0, 0]), 'boxes': array([[0.48874995, 0.5265625 , 0.2425 , 0.178125 ],
[0.4875 , 0.47625 , 0.435 , 0.35375 ],
[0.41666663, 0.30200002, 0.06800003, 0.14800003],
[0.33866668, 0.353 , 0.10399998, 0.19 ],
[0.42466667, 0.301 , 0.04399998, 0.07799999],
[0.34199998, 0.377 , 0.05199997, 0.09800003],
[0.44666663, 0.62700003, 0.11200001, 0.11800003],
[0.23333333, 0.525 , 0.04266666, 0.10200001],
[0.31 , 0.585 , 0.23599999, 0.742 ],
[0.44266662, 0.556 , 0.21333332, 0.732 ]], dtype=float32), 'area': array([230374.98 , 820699.94 , 10483.333 , 20583.332 , 3574.9998,
5308.333 , 13766.666 , 4533.333 , 182408.33 , 162666.66 ],
dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}, {'size': array([800, 800, 800, 800]), 'image_id': array([483, 900]), 'class_labels': array([4, 3, 4, 3, 4, 2, 2, 2, 2, 0, 0]), 'boxes': array([[0.61621094, 0.35054943, 0.11230469, 0.16336995],
[0.7 , 0.3126649 , 0.06400002, 0.05540899],
[0.71099997, 0.35751975, 0.05 , 0.07651713],
[0.289 , 0.16886543, 0.10200001, 0.07387863],
[0.294 , 0.20580475, 0.06799999, 0.05804748],
[0.69699997, 0.707124 , 0.09799995, 0.0949868 ],
[0.63 , 0.6873351 , 0.072 , 0.09762539],
[0.33699998, 0.5501319 , 0.11399999, 0.08707123],
[0.172 , 0.5131926 , 0.05599999, 0.06596302],
[0.791 , 0.58707124, 0.27800003, 0.8205804 ],
[0.345 , 0.5092348 , 0.37800002, 0.9762533 ]], dtype=float32), 'area': array([142472.22 , 1866.6666, 2013.8888, 3966.6665, 2077.7776,
4900. , 3699.9998, 5225. , 1944.4443, 120080.555 ,
194250. ], dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}, {'size': array([800, 800, 800, 800]), 'image_id': array([192, 898]), 'class_labels': array([4, 4, 4, 4, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 4]), 'boxes': array([[0.77111113, 0.22444445, 0.05777778, 0.11555557],
[0.52611107, 0.15333334, 0.06777775, 0.09777777],
[0.3611111 , 0.24999999, 0.05111115, 0.08666664],
[0.19166666, 0.25 , 0.07666665, 0.1 ],
[0.85444444, 0.7833333 , 0.08666664, 0.18 ],
[0.68388885, 0.74444443, 0.05 , 0.21777779],
[0.63888884, 0.6755556 , 0.06444439, 0.18222222],
[0.4377778 , 0.7277777 , 0.09111111, 0.15777779],
[0.28444442, 0.7544444 , 0.05777777, 0.20222221],
[0.12111111, 0.77777773, 0.07111111, 0.10222214],
[0.76555556, 0.5088889 , 0.23333336, 0.9288889 ],
[0.55277777, 0.5044445 , 0.21222222, 0.9777778 ],
[0.3572222 , 0.52111113, 0.13444445, 0.9577778 ],
[0.1711111 , 0.5288889 , 0.20444442, 0.7955556 ],
[0.4296875 , 0.28125 , 0.103125 , 0.0875 ]], dtype=float32), 'area': array([ 7511.111 , 7455.555 , 4983.333 , 8625. , 17550. ,
12250. , 13211.11 , 16172.222 , 13144.444 , 8177.7773,
243833.33 , 233444.44 , 144863.89 , 182977.77 , 23100. ],
dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}, {'size': array([800, 800, 800, 800]), 'image_id': array([722, 611]), 'class_labels': array([0, 4, 4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 0]), 'boxes': array([[0.525 , 0.6805556 , 0.3640625 , 0.6388889 ],
[0.5445312 , 0.7451389 , 0.1078125 , 0.17916672],
[0.99375 , 0.51555556, 0.0125 , 0.04888889],
[0.96 , 0.47111106, 0.0175 , 0.04444443],
[0.7 , 0.3188889 , 0.0575 , 0.11777779],
[0.43375 , 0.44888887, 0.04500004, 0.07555553],
[0.201875 , 0.39 , 0.03874998, 0.05111115],
[0.655 , 0.85444444, 0.055 , 0.10444443],
[0.51875 , 0.7533333 , 0.0425 , 0.04000008],
[0.463125 , 0.7888889 , 0.06625004, 0.10222221],
[0.344375 , 0.9733333 , 0.05125 , 0.04888893],
[0.08125 , 0.93 , 0.0575 , 0.13111107],
[0.253125 , 0.85 , 0.03875002, 0.12222221],
[0.690625 , 0.27888888, 0.08375 , 0.1 ],
[0.43125 , 0.4077778 , 0.0625 , 0.05111111],
[0.2075 , 0.3488889 , 0.05249998, 0.04444443],
[0.770625 , 0.56999993, 0.23625 , 0.8555555 ],
[0.941875 , 0.6822222 , 0.10375 , 0.6222222 ],
[0.436875 , 0.73333335, 0.19874996, 0.52444446],
[0.15375 , 0.60999995, 0.2425 , 0.76666665]], dtype=float32), 'area': array([595444.44 , 49450. , 611.1111, 777.7778, 6772.222 ,
3399.9998, 1980.5554, 5744.4443, 1699.9999, 6772.222 ,
2505.5554, 7538.8887, 4736.111 , 8375. , 3194.4443,
2333.3333, 202125. , 64555.555 , 104233.33 , 185916.66 ],
dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}, {'size': array([800, 800, 800, 800]), 'image_id': array([965, 756]), 'class_labels': array([2, 2, 1, 0, 4]), 'boxes': array([[0.47208327, 0.635625 , 0.17583331, 0.13875003],
[0.47041664, 0.806875 , 0.11916664, 0.24625 ],
[0.6725 , 0.430625 , 0.1716666 , 0.43625 ],
[0.75 , 0.5425 , 0.49666667, 0.915 ],
[0.73398435, 0.6986111 , 0.34140626, 0.5944444 ]], dtype=float32), 'area': array([ 65058.332, 78252.77 , 199705.55 , 1211866.6 , 519544.44 ],
dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}]
There is no indication that `boxes[0]` originated from the first input, nor of how many boxes each of the other inputs has.
Is this unintended, or am I missing something?
I suspect that `nested_concat` is the culprit for merging the lists of boxes across batches in this way.
You can also see other bugs related to `nested_concat` in the `size` field (e.g. `array([800, 800, 800, 800])`), but those are outside the scope of my current question.