How to go about calculating custom metrics when nested_concat merges labels irreversibly

edit:
run with a gpu;

I have stumbled upon a quirk while trying to figure out how to calculate custom metrics, using a setup similar to the one in the "Object detection" guide.

I am using a DETR model for object detection. When I attempt to use a custom metric, the expanded label field is merged in a way that makes it impossible to use or to associate with the original inputs:

def compute_metrics(eval_pred):
    """Debug metric hook: print the labels it receives and return a dummy score.

    Args:
        eval_pred: A two-item iterable ``(predictions, labels)``, where
            ``predictions`` itself unpacks into exactly five items
            (loss dict, logits, predicted boxes, last hidden state,
            encoder last hidden state).

    Returns:
        The constant placeholder mapping ``{"dummy": 1}``.
    """
    model_outputs, batched_labels = eval_pred
    # Strict five-way unpack: fails loudly if the prediction tuple ever has a
    # different number of entries. The individual values are not used here.
    (
        loss_dict,
        logits,
        pred_boxes,
        last_hidden_state,
        encoder_last_hidden_state,
    ) = model_outputs
    # Show the labels exactly as they arrive, to inspect how they were merged.
    print(batched_labels)
    return {"dummy": 1}
from transformers import Trainer

# The same "train" split of `cppe5` is used as both the training and the
# evaluation dataset in this reproduction.
train_split = cppe5["train"]

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_split,
    eval_dataset=train_split,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

# Evaluate on only the first 16 examples (presumably to keep the labels
# printed by compute_metrics short enough to read).
first_examples = train_split.select(range(16))
trainer.evaluate(first_examples)

image
image

[{'size': array([800, 800, 800, 800]), 'image_id': array([ 15, 285]), 'class_labels': array([4, 4, 0, 0, 0]), 'boxes': array([[0.64103925, 0.2036199 , 0.07741249, 0.07843136],
       [0.11081654, 0.1719457 , 0.06044539, 0.04223229],
       [0.6988335 , 0.5113122 , 0.26299044, 0.9291101 ],
       [0.10710498, 0.40497738, 0.21420996, 0.60482657],
       [0.568     , 0.53099996, 0.544     , 0.938     ]], dtype=float32), 'area': array([ 10544.444,   4433.333, 424355.53 , 225005.55 , 265766.66 ],
      dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}, {'size': array([800, 800, 800, 800]), 'image_id': array([762, 834]), 'class_labels': array([4, 4, 4, 2, 2, 2, 2, 2, 0, 0, 3, 0, 4]), 'boxes': array([[0.8878005 , 0.3708278 , 0.06233307, 0.0847797 ],
       [0.73219055, 0.35580772, 0.07791626, 0.06675567],
       [0.50845945, 0.3538051 , 0.05520931, 0.06275036],
       [0.7609083 , 0.57142854, 0.06589493, 0.06008007],
       [0.8566341 , 0.6124833 , 0.03116646, 0.01668892],
       [0.8210152 , 0.63818425, 0.05164741, 0.06408546],
       [0.6622885 , 0.5257009 , 0.0503117 , 0.06074768],
       [0.61130893, 0.5487316 , 0.06678536, 0.08144192],
       [0.91184324, 0.5684245 , 0.17542297, 0.453271  ],
       [0.5331701 , 0.4289052 , 0.12600182, 0.17022698],
       [0.5347284 , 0.31041387, 0.08014248, 0.05473965],
       [0.35507813, 0.63125   , 0.26953125, 0.7375    ],
       [0.38125   , 0.46805555, 0.096875  , 0.09722222]], dtype=float32), 'area': array([ 49388.887,  48611.11 ,  32377.777,  37000.   ,   4861.111,
        30933.332,  28563.889,  50833.332, 743127.75 , 200458.33 ,
        41000.   , 508874.97 ,  24111.11 ], dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}, {'size': array([800, 800, 800, 800]), 'image_id': array([656,  39]), 'class_labels': array([0, 4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0]), 'boxes': array([[0.49375   , 0.5694444 , 0.49375   , 0.86111116],
       [0.8590625 , 0.25833333, 0.12312508, 0.11222223],
       [0.6671875 , 0.245     , 0.091875  , 0.08333336],
       [0.5065625 , 0.24222222, 0.08187496, 0.12888888],
       [0.40375   , 0.22277777, 0.06125   , 0.11      ],
       [0.24124998, 0.30555555, 0.0775    , 0.07555553],
       [0.948125  , 0.78055555, 0.09625   , 0.13888885],
       [0.801875  , 0.75722224, 0.0875    , 0.15      ],
       [0.59625   , 0.74888885, 0.12875   , 0.18      ],
       [0.4878125 , 0.6783333 , 0.091875  , 0.15666664],
       [0.3890625 , 0.70555556, 0.076875  , 0.12      ],
       [0.349375  , 0.745     , 0.06124996, 0.15666671],
       [0.200625  , 0.8416667 , 0.0575    , 0.17000008],
       [0.878125  , 0.5733333 , 0.2       , 0.85333335],
       [0.6609375 , 0.5422222 , 0.228125  , 0.9155556 ],
       [0.40875   , 0.5516667 , 0.18624996, 0.89666665],
       [0.225625  , 0.55      , 0.24875   , 0.9       ]], dtype=float32), 'area': array([1088444.4  ,   55269.44 ,   30624.998,   42211.11 ,   26949.998,
         23422.22 ,   53472.22 ,   52499.996,   92700.   ,   57574.996,
         36900.   ,   38383.332,   39100.   ,  682666.6  ,  835444.44 ,
        668016.6  ,  895499.94 ], dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}, {'size': array([800, 800, 800, 800]), 'image_id': array([419,  49]), 'class_labels': array([4, 1, 1, 1, 4, 4, 2, 2, 0, 0]), 'boxes': array([[0.48874995, 0.5265625 , 0.2425    , 0.178125  ],
       [0.4875    , 0.47625   , 0.435     , 0.35375   ],
       [0.41666663, 0.30200002, 0.06800003, 0.14800003],
       [0.33866668, 0.353     , 0.10399998, 0.19      ],
       [0.42466667, 0.301     , 0.04399998, 0.07799999],
       [0.34199998, 0.377     , 0.05199997, 0.09800003],
       [0.44666663, 0.62700003, 0.11200001, 0.11800003],
       [0.23333333, 0.525     , 0.04266666, 0.10200001],
       [0.31      , 0.585     , 0.23599999, 0.742     ],
       [0.44266662, 0.556     , 0.21333332, 0.732     ]], dtype=float32), 'area': array([230374.98  , 820699.94  ,  10483.333 ,  20583.332 ,   3574.9998,
         5308.333 ,  13766.666 ,   4533.333 , 182408.33  , 162666.66  ],
      dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}, {'size': array([800, 800, 800, 800]), 'image_id': array([483, 900]), 'class_labels': array([4, 3, 4, 3, 4, 2, 2, 2, 2, 0, 0]), 'boxes': array([[0.61621094, 0.35054943, 0.11230469, 0.16336995],
       [0.7       , 0.3126649 , 0.06400002, 0.05540899],
       [0.71099997, 0.35751975, 0.05      , 0.07651713],
       [0.289     , 0.16886543, 0.10200001, 0.07387863],
       [0.294     , 0.20580475, 0.06799999, 0.05804748],
       [0.69699997, 0.707124  , 0.09799995, 0.0949868 ],
       [0.63      , 0.6873351 , 0.072     , 0.09762539],
       [0.33699998, 0.5501319 , 0.11399999, 0.08707123],
       [0.172     , 0.5131926 , 0.05599999, 0.06596302],
       [0.791     , 0.58707124, 0.27800003, 0.8205804 ],
       [0.345     , 0.5092348 , 0.37800002, 0.9762533 ]], dtype=float32), 'area': array([142472.22  ,   1866.6666,   2013.8888,   3966.6665,   2077.7776,
         4900.    ,   3699.9998,   5225.    ,   1944.4443, 120080.555 ,
       194250.    ], dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}, {'size': array([800, 800, 800, 800]), 'image_id': array([192, 898]), 'class_labels': array([4, 4, 4, 4, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 4]), 'boxes': array([[0.77111113, 0.22444445, 0.05777778, 0.11555557],
       [0.52611107, 0.15333334, 0.06777775, 0.09777777],
       [0.3611111 , 0.24999999, 0.05111115, 0.08666664],
       [0.19166666, 0.25      , 0.07666665, 0.1       ],
       [0.85444444, 0.7833333 , 0.08666664, 0.18      ],
       [0.68388885, 0.74444443, 0.05      , 0.21777779],
       [0.63888884, 0.6755556 , 0.06444439, 0.18222222],
       [0.4377778 , 0.7277777 , 0.09111111, 0.15777779],
       [0.28444442, 0.7544444 , 0.05777777, 0.20222221],
       [0.12111111, 0.77777773, 0.07111111, 0.10222214],
       [0.76555556, 0.5088889 , 0.23333336, 0.9288889 ],
       [0.55277777, 0.5044445 , 0.21222222, 0.9777778 ],
       [0.3572222 , 0.52111113, 0.13444445, 0.9577778 ],
       [0.1711111 , 0.5288889 , 0.20444442, 0.7955556 ],
       [0.4296875 , 0.28125   , 0.103125  , 0.0875    ]], dtype=float32), 'area': array([  7511.111 ,   7455.555 ,   4983.333 ,   8625.    ,  17550.    ,
        12250.    ,  13211.11  ,  16172.222 ,  13144.444 ,   8177.7773,
       243833.33  , 233444.44  , 144863.89  , 182977.77  ,  23100.    ],
      dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}, {'size': array([800, 800, 800, 800]), 'image_id': array([722, 611]), 'class_labels': array([0, 4, 4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 0]), 'boxes': array([[0.525     , 0.6805556 , 0.3640625 , 0.6388889 ],
       [0.5445312 , 0.7451389 , 0.1078125 , 0.17916672],
       [0.99375   , 0.51555556, 0.0125    , 0.04888889],
       [0.96      , 0.47111106, 0.0175    , 0.04444443],
       [0.7       , 0.3188889 , 0.0575    , 0.11777779],
       [0.43375   , 0.44888887, 0.04500004, 0.07555553],
       [0.201875  , 0.39      , 0.03874998, 0.05111115],
       [0.655     , 0.85444444, 0.055     , 0.10444443],
       [0.51875   , 0.7533333 , 0.0425    , 0.04000008],
       [0.463125  , 0.7888889 , 0.06625004, 0.10222221],
       [0.344375  , 0.9733333 , 0.05125   , 0.04888893],
       [0.08125   , 0.93      , 0.0575    , 0.13111107],
       [0.253125  , 0.85      , 0.03875002, 0.12222221],
       [0.690625  , 0.27888888, 0.08375   , 0.1       ],
       [0.43125   , 0.4077778 , 0.0625    , 0.05111111],
       [0.2075    , 0.3488889 , 0.05249998, 0.04444443],
       [0.770625  , 0.56999993, 0.23625   , 0.8555555 ],
       [0.941875  , 0.6822222 , 0.10375   , 0.6222222 ],
       [0.436875  , 0.73333335, 0.19874996, 0.52444446],
       [0.15375   , 0.60999995, 0.2425    , 0.76666665]], dtype=float32), 'area': array([595444.44  ,  49450.    ,    611.1111,    777.7778,   6772.222 ,
         3399.9998,   1980.5554,   5744.4443,   1699.9999,   6772.222 ,
         2505.5554,   7538.8887,   4736.111 ,   8375.    ,   3194.4443,
         2333.3333, 202125.    ,  64555.555 , 104233.33  , 185916.66  ],
      dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}, {'size': array([800, 800, 800, 800]), 'image_id': array([965, 756]), 'class_labels': array([2, 2, 1, 0, 4]), 'boxes': array([[0.47208327, 0.635625  , 0.17583331, 0.13875003],
       [0.47041664, 0.806875  , 0.11916664, 0.24625   ],
       [0.6725    , 0.430625  , 0.1716666 , 0.43625   ],
       [0.75      , 0.5425    , 0.49666667, 0.915     ],
       [0.73398435, 0.6986111 , 0.34140626, 0.5944444 ]], dtype=float32), 'area': array([  65058.332,   78252.77 ,  199705.55 , 1211866.6  ,  519544.44 ],
      dtype=float32), 'iscrowd': array([0, 0, 0, 0, 0]), 'orig_size': array([480, 480, 480, 480])}]

There is no indication that boxes[0] originated from the first input, nor of how many boxes each of the other inputs has.

Is that unintended? Am I missing something?
I figured that nested_concat is the culprit for merging the list of boxes over the batches in this way.
You can also see different bugs related to nested_concat in the `size` field, but that’s outside the scope of my current question.