This was helpful. I implemented my own solution, similar to the one on GitHub, and so far it runs without errors. I'm adding it here in case someone else finds it useful, and in case someone spots bugs I missed — either way, please let me know. Thanks!
@dataclass
class ModelOutput:
    """Lightweight container for raw detection-model outputs.

    Bundles the logits and predicted boxes into a single object so they can be
    handed to ``image_processor.post_process_object_detection`` as one argument.
    """

    # Raw classification logits, exactly as returned by the model.
    logits: torch.Tensor
    # Raw predicted boxes, exactly as returned by the model.
    pred_boxes: torch.Tensor
class BatchMAPEvaluator:
    """Accumulate mean average precision over evaluation batches.

    Designed to be passed as ``compute_metrics`` to a Hugging Face ``Trainer``
    running with ``batch_eval_metrics=True``. Each call feeds one batch into a
    torchmetrics ``MeanAveragePrecision``, and the final call returns a flat
    dict of rounded metrics, including one ``map_<class>`` and one
    ``mar_100_<class>`` entry per class.

    Args:
        image_processor: object providing ``post_process_object_detection``,
            used to turn raw model outputs into per-image detections.
        threshold: forwarded as ``threshold`` to
            ``post_process_object_detection``.
        id2label: optional mapping from class id to class name, used to name
            the per-class metrics. When ``None``, the numeric id is used.
    """

    def __init__(self, image_processor, threshold=0.00, id2label=None):
        self.image_processor = image_processor
        self.threshold = threshold
        self.id2label = id2label
        # Targets are converted to absolute xyxy in process_batch. The
        # processor's post-processed predictions are assumed to use the same
        # format.
        self.evaluator = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
        self.evaluator.warn_on_many_detections = False

    def reset(self):
        """Reset the evaluator state for a new evaluation run."""
        self.evaluator.reset()

    def process_batch(self, predictions, targets):
        """Convert one batch to absolute xyxy boxes and feed it to the evaluator.

        Args:
            predictions: sequence of model outputs as collected by the Trainer.
                Element 1 must be the class logits and element 2 the predicted
                boxes. Element 0 (presumably the loss entry) is not used.
                Tensors or array-likes are accepted.
            targets: list of per-image dicts with ``"size"`` as (height, width),
                ``"boxes"`` as normalized (cx, cy, w, h) rows, and
                ``"class_labels"``.
        """
        image_sizes = torch.stack([torch.as_tensor(t["size"]) for t in targets]).cpu()

        # as_tensor leaves tensors as-is (then .cpu() moves them) and wraps
        # numpy arrays without copying. This replaces two isinstance branches.
        batch_logits = torch.as_tensor(predictions[1]).cpu()
        batch_boxes = torch.as_tensor(predictions[2]).cpu()
        output = ModelOutput(logits=batch_logits, pred_boxes=batch_boxes)
        post_processed_predictions = self.image_processor.post_process_object_detection(
            output, threshold=self.threshold, target_sizes=image_sizes
        )

        post_processed_targets = []
        for target, (height, width) in zip(targets, image_sizes.tolist()):
            # reshape(-1, 4) keeps images without annotations as an empty (0, 4)
            # tensor. The conversion stays in float32 torch (no numpy round trip
            # that would upcast to float64).
            boxes = torch.as_tensor(target["boxes"]).cpu().float().reshape(-1, 4)
            cx, cy, w, h = boxes.unbind(-1)
            boxes = torch.stack((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2), dim=-1)
            # Scale from normalized coordinates to absolute pixels.
            boxes = boxes * torch.tensor([width, height, width, height], dtype=boxes.dtype)
            post_processed_targets.append(
                {"boxes": boxes, "labels": torch.as_tensor(target["class_labels"]).cpu()}
            )

        self.evaluator.update(post_processed_predictions, post_processed_targets)

    def compute(self):
        """Compute final metrics from everything accumulated so far.

        Returns:
            dict mapping metric name to a float rounded to 4 decimals. The
            per-class lists are flattened into ``map_<name>`` and
            ``mar_100_<name>`` entries.
        """
        metrics = self.evaluator.compute()
        # torchmetrics may return 0-d tensors for these entries (e.g. when a
        # single class is present). Iterating a 0-d tensor raises TypeError, so
        # promote them to 1-d first.
        classes = torch.atleast_1d(metrics.pop("classes"))
        map_per_class = torch.atleast_1d(metrics.pop("map_per_class"))
        mar_100_per_class = torch.atleast_1d(metrics.pop("mar_100_per_class"))
        for class_id, class_map, class_mar in zip(classes.tolist(), map_per_class, mar_100_per_class):
            class_name = self.id2label[class_id] if self.id2label is not None else class_id
            metrics[f"map_{class_name}"] = class_map
            metrics[f"mar_100_{class_name}"] = class_mar
        return {k: round(v.item(), 4) for k, v in metrics.items()}

    @torch.no_grad()
    def __call__(self, evaluation_results, compute_result: bool):
        """Trainer ``compute_metrics`` hook for ``batch_eval_metrics=True``.

        With batched eval metrics, the Trainer passes the *last* batch's
        predictions in the same call that sets ``compute_result=True``. The
        batch is therefore always accumulated first; otherwise the final batch
        would be missing from the metrics.

        Returns:
            ``{}`` for intermediate batches, or the final metrics dict when
            ``compute_result`` is True. In that case the evaluator is reset
            even if computing fails.
        """
        self.process_batch(evaluation_results.predictions, evaluation_results.label_ids)
        if not compute_result:
            return {}
        try:
            return self.compute()
        finally:
            self.reset()
# Create the batched mAP evaluator instance. `processor` and `id2label` are
# assumed to be defined elsewhere in the script (not shown in this snippet).
# NOTE(review): `threshold` is forwarded to post_process_object_detection. If it
# filters out low-score detections, 0.30 may understate mAP; confirm it is intended.
batch_eval_compute_metrics_fn = BatchMAPEvaluator(image_processor=processor, threshold=0.30, id2label=id2label)
# Enable batch_eval_metrics in TrainingArguments. compute_metrics is then invoked
# per evaluation batch with a `compute_result` flag (see BatchMAPEvaluator.__call__)
# instead of once on all accumulated predictions.
training_args = TrainingArguments(
# `...` stands for the remaining training arguments elided from this snippet.
...
batch_eval_metrics=True,
)
# Pass the batched mAP evaluator instance to the Trainer as compute_metrics.
trainer = Trainer(
# `...` stands for the remaining Trainer arguments elided from this snippet.
...
compute_metrics=batch_eval_compute_metrics_fn,
)