CUDA Out Of Memory when training a DETR Object detection model with compute_metrics

This was helpful. I implemented my own solution, similar to the one on GitHub. So far, no errors. I'll add it here in case someone else finds it useful, or in case someone spots bugs I missed. Either way, please let me know. Thanks!

from dataclasses import dataclass

import numpy as np
import supervision as sv
import torch
from torchmetrics.detection import MeanAveragePrecision
from transformers import Trainer, TrainingArguments


# Minimal container with the .logits / .pred_boxes attributes that
# post_process_object_detection expects
@dataclass
class ModelOutput:
    logits: torch.Tensor
    pred_boxes: torch.Tensor

class BatchMAPEvaluator:

    def __init__(self, image_processor, threshold=0.00, id2label=None):
        self.image_processor = image_processor
        self.threshold = threshold
        self.id2label = id2label
        self.evaluator = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
        self.evaluator.warn_on_many_detections = False

    def reset(self):
        """Reset the evaluator state for a new evaluation run."""
        self.evaluator.reset()

    def process_batch(self, predictions, targets):
        """
        Process a single batch and update the evaluator.

        Args:
            predictions: tuple of (loss, logits, pred_boxes)
            targets: list of dicts, each with "size", "boxes", "class_labels" as tensors
        """
        # Get image sizes - targets is a list of dicts with tensor values
        image_sizes = torch.stack([x["size"] for x in targets]).cpu()

        # Process predictions
        batch_logits = predictions[1]
        batch_boxes = predictions[2]

        # Ensure tensors are on CPU
        if isinstance(batch_logits, torch.Tensor):
            batch_logits = batch_logits.cpu()
        else:
            batch_logits = torch.tensor(batch_logits)

        if isinstance(batch_boxes, torch.Tensor):
            batch_boxes = batch_boxes.cpu()
        else:
            batch_boxes = torch.tensor(batch_boxes)

        output = ModelOutput(logits=batch_logits, pred_boxes=batch_boxes)
        post_processed_predictions = self.image_processor.post_process_object_detection(
            output, threshold=self.threshold, target_sizes=image_sizes
        )

        # Process targets
        post_processed_targets = []
        for target, (height, width) in zip(targets, image_sizes):
            # Move tensors to CPU and convert to numpy
            boxes = target["boxes"].cpu().numpy()
            labels = target["class_labels"].cpu()  # Already a torch tensor, just moved to CPU

            # Convert xcycwh to xyxy and scale to image size
            boxes = sv.xcycwh_to_xyxy(boxes)
            boxes = boxes * np.array([width.item(), height.item(), width.item(), height.item()])

            post_processed_targets.append({
                "boxes": torch.tensor(boxes),
                "labels": labels
            })

        # Update evaluator with this batch
        self.evaluator.update(post_processed_predictions, post_processed_targets)

    def compute(self):
        """Compute final metrics after all batches have been processed."""
        metrics = self.evaluator.compute()

        classes = metrics.pop("classes")
        map_per_class = metrics.pop("map_per_class")
        mar_100_per_class = metrics.pop("mar_100_per_class")
        for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
            class_name = self.id2label[class_id.item()] if self.id2label is not None else class_id.item()
            metrics[f"map_{class_name}"] = class_map
            metrics[f"mar_100_{class_name}"] = class_mar

        metrics = {k: round(v.item(), 4) for k, v in metrics.items()}

        return metrics

    @torch.no_grad()
    def __call__(self, evaluation_results, compute_result: bool):
        # Accumulate every batch, including the final one: the Trainer passes the
        # last batch's predictions together with compute_result=True, so it has to
        # be processed here before the aggregated metrics are computed.
        predictions = evaluation_results.predictions
        targets = evaluation_results.label_ids
        self.process_batch(predictions, targets)

        if not compute_result:
            return {}

        # Final batch: compute the aggregated metrics and reset for the next eval run
        metrics = self.compute()
        self.reset()
        return metrics
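
If you want to sanity-check the evaluator outside the Trainer, a minimal smoke test along these lines exercises process_batch and compute directly. Everything in it is illustrative: random tensors standing in for DETR outputs (100 queries, 3 classes plus the no-object logit), a made-up 3-class id2label, and the DETR image processor (processor) that the rest of the snippet already assumes is defined.

# Hypothetical smoke test with random tensors (not part of the training setup)
dummy_logits = torch.randn(2, 100, 4)              # (batch, queries, num_classes + 1)
dummy_boxes = torch.rand(2, 100, 4)                # normalized cxcywh, as DETR predicts
dummy_targets = [
    {"size": torch.tensor([480, 640]),             # (height, width) after preprocessing
     "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.3]]),
     "class_labels": torch.tensor([1])},
    {"size": torch.tensor([480, 640]),
     "boxes": torch.tensor([[0.4, 0.4, 0.1, 0.1]]),
     "class_labels": torch.tensor([2])},
]
dummy_id2label = {0: "cat", 1: "dog", 2: "bird"}   # hypothetical label map

smoke_evaluator = BatchMAPEvaluator(image_processor=processor, threshold=0.0, id2label=dummy_id2label)
smoke_evaluator.process_batch((None, dummy_logits, dummy_boxes), dummy_targets)
print(smoke_evaluator.compute())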

# Create the batched mAP evaluator instance
batch_eval_compute_metrics_fn = BatchMAPEvaluator(image_processor=processor, threshold=0.30, id2label=id2label)

# Set batch_eval_metrics in TrainingArguments so compute_metrics runs per batch
# instead of on all concatenated predictions (this is what avoids the CUDA OOM)
training_args = TrainingArguments(
    ...
    batch_eval_metrics=True,
)
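
For context, a fuller configuration might look something like the sketch below. batch_eval_metrics is the only argument that matters for this thread; the rest are purely illustrative, and the exact argument names can vary a bit between transformers versions.

training_args_example = TrainingArguments(
    output_dir="detr-finetuned",        # hypothetical output directory
    per_device_eval_batch_size=8,       # illustrative value
    eval_strategy="epoch",              # evaluate once per epoch
    remove_unused_columns=False,        # keep the columns an object-detection collator needs
    batch_eval_metrics=True,            # compute metrics batch by batch
)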

# Pass the batched mAP evaluator to the Trainer as compute_metrics
trainer = Trainer(
    ...
    compute_metrics=batch_eval_compute_metrics_fn,
)
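
With batch_eval_metrics=True, the Trainer calls the evaluator once per eval batch, passing compute_result=True only on the final batch, so the aggregated numbers come back from a normal evaluate() call (the Trainer prefixes the keys with eval_):

metrics = trainer.evaluate()
print(metrics)  # e.g. eval_map, eval_map_50, eval_mar_100, plus the per-class entries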