How to fine-tune Segment Anything Model (SAM) with multiple points

The most elegant solution I could come up with is to re-implement the _pad_points_and_labels method of SamProcessor as a standalone function, pass only the batch of images to the SamImageProcessor, and pad the points and labels outside of the processor:

def _pad_points_and_labels(input_points, input_labels):
    r"""
    The method pads the 2D points and labels to the maximum number of points in the batch.
    Adapted from processing_sam.py
    """
    point_pad_value = -10
    input_points = [np.array(point) for point in input_points]
    expected_nb_points = max([point.shape[0] for point in input_points])
    processed_input_points = []
    for i, point in enumerate(input_points):
        if point.shape[0] != expected_nb_points:
            input_labels[i] = np.append(input_labels[i],
                                        [point_pad_value] * (expected_nb_points - point.shape[0]))
            point = np.concatenate(
                [point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
            )
        input_labels[i] = torch.tensor(input_labels[i])
        processed_input_points.append(torch.tensor(point))
    input_points = processed_input_points

    input_labels = torch.stack(input_labels, dim=0)
    input_points = torch.stack(input_points, dim=0).unsqueeze(1)
    return input_points, input_labels

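I use this function in the collate function as follows (here image_processor is a SamImageProcessor instance and default_collate comes from torch.utils.data):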

def _custom_collate_fn(batch):
    """
    Collate a list of sample dicts into one batch dict for SAM fine-tuning.

    Images go through `image_processor` on their own, while point prompts and
    labels are padded by `_pad_points_and_labels`; every other key is collated
    with `default_collate`.

    Args:
        batch: list of dict samples, each holding at least the keys 'cv_image',
            'point_prompts' and 'point_labels'. The samples are not modified.

    Returns:
        The collated dict, updated with the image processor outputs plus
        'input_points' and 'input_labels', and with 'cv_image' set to the list
        of original (unprocessed) images.
    """
    prompt_keys = ('cv_image', 'point_prompts', 'point_labels')

    # Read instead of pop: the dataset may hand out the same dict objects again
    # (e.g. an in-memory or cached dataset), and removing keys would break the
    # next pass over the data.
    images = [sample['cv_image'] for sample in batch]
    points = [sample['point_prompts'] for sample in batch]
    input_labels = [sample['point_labels'] for sample in batch]

    # pass the images separately to SamImageProcessor
    inputs = image_processor(images, return_tensors="pt")

    # pad points and labels outside the image_processor
    inputs['input_points'], inputs['input_labels'] = _pad_points_and_labels(points, input_labels)

    # default_collate only sees the remaining keys; the variable-length prompts
    # were already batched above.
    remaining = [
        {key: value for key, value in sample.items() if key not in prompt_keys}
        for sample in batch
    ]
    out_dict = default_collate(remaining)
    out_dict.update(inputs)
    out_dict['cv_image'] = images

    return out_dict

Hope this helps.