How to fine-tune Segment Anything Model (SAM) with multiple points

As the title suggests, I want to fine-tune SAM in batches with a varying number of points per sample.
I couldn't find an example of fine-tuning SAM that uses multiple point prompts; the examples I found only fine-tune with bounding boxes, like this example.

This is the gist of my code. input_points is currently just a list of lists of lists of xy coordinates (shape [batch, 1, num_pts, 2]).
It can't be stacked and converted to a torch tensor because num_pts varies from sample to sample.

import random

import numpy as np
from torch.utils.data import Dataset
from transformers import SamModel, SamProcessor


class MyDataset(Dataset):

    def __init__(self):
        self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    def __getitem__(self, idx):
        # a varying number of random (x, y) point prompts per sample
        num_pts = random.randint(1, 5)
        point_prompts = [(random.random(), random.random()) for _ in range(num_pts)]

        # dummy RGB image
        image = np.random.rand(224, 224, 3) * 255
        image = image.astype(np.uint8)

        inputs = self.processor(image, input_points=[point_prompts])
        return inputs

def train():
    model = SamModel.from_pretrained("facebook/sam-vit-base")

    for batch in train_loader:
        # pixel_values and the point prompts come out of the processed batch
        outputs = model(pixel_values=batch["pixel_values"],
                        input_points=batch["input_points"],
                        multimask_output=False)
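
To make the stacking problem concrete, here is a minimal illustration (coordinates made up) of why the ragged point lists can't simply be turned into one tensor:

import torch

# 3 points for the first sample, 1 point for the second -> ragged nesting
ragged_points = [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
                 [[0.7, 0.8]]]

# raises a ValueError because the inner sequences have different lengths
torch.tensor(ragged_points)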

I read in the documentation that the processor pads the points to make num_pts the same across all samples, but I don’t understand how this is supposed to work.

The points can be obtained by passing a list of list of list to the processor, which will create corresponding torch tensors of dimension 4:

- the first dimension is the image batch size,
- the second dimension is the point batch size (i.e. how many segmentation masks we want the model to predict per input point),
- the third dimension is the number of points per segmentation mask (it is possible to pass multiple points for a single mask),
- the last dimension is the x (vertical) and y (horizontal) coordinates of the point.

If a different number of points is passed either for each image or for each mask, the processor will create "PAD" points that correspond to the (0, 0) coordinate, and the computation of the embedding will be skipped for these points using the labels.
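
For reference, a minimal sketch of such a call and the shapes it should produce. The images, coordinates and labels are made up, and whether this works out of the box or needs the patch discussed further down depends on your transformers version:

import numpy as np
from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# two dummy images with a different number of point prompts each
images = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(2)]
input_points = [[[50, 60], [100, 120], [30, 40]],  # 3 points for image 0
                [[70, 80]]]                        # 1 point for image 1
input_labels = [[1, 1, 1],                         # 1 = foreground point
                [1]]

inputs = processor(images, input_points=input_points,
                   input_labels=input_labels, return_tensors="pt")

print(inputs["input_points"].shape)  # torch.Size([2, 1, 3, 2]) -- both samples padded to 3 points
print(inputs["input_labels"])        # padded slots get the pad label -10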


I was able to make a little progress with my problem.
It seems like you have to pass your images and input_points to the SamProcessor batch-wise, and to pass input_labels in addition (see here).

However, this only works bug-free after adjusting _pad_points_and_labels in the source code of processing_sam.py in the following way:

def _pad_points_and_labels(self, input_points, input_labels):
    r"""
    The method pads the 2D points and labels to the maximum number of points in the batch.
    """
    expected_nb_points = max([point.shape[0] for point in input_points])
    processed_input_points = []
    for i, point in enumerate(input_points):
        if point.shape[0] != expected_nb_points:
            # pad the labels with one pad value per missing point
            input_labels[i] = np.append(input_labels[i],
                                        [self.point_pad_value] * (expected_nb_points - point.shape[0]))
            # pad the points themselves with pad-value coordinates
            point = np.concatenate(
                [point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0
            )
        processed_input_points.append(point)
    input_points = processed_input_points
    return input_points, input_labels
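
If you don't want to edit the installed package, monkey-patching the method at runtime is an alternative (a sketch, assuming the patched function above is defined in your own module with the same self signature):

from transformers import SamProcessor

# replace the library method with the patched version defined above
SamProcessor._pad_points_and_labels = _pad_points_and_labels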

I moved the processing part from the __getitem__ method to the collate_fn of the data loader.
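
For concreteness, a rough sketch of what that collate_fn looks like (the sample keys "image", "points" and "labels" are just illustrative names for what __getitem__ now returns instead of the processed inputs):

from torch.utils.data import DataLoader
from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

def collate_fn(batch):
    # each sample holds a raw image, its point prompts and one label per point
    images = [sample["image"] for sample in batch]
    points = [sample["points"] for sample in batch]
    labels = [sample["labels"] for sample in batch]

    # the (patched) processor pads points and labels to the longest sample in the batch
    return processor(images, input_points=points, input_labels=labels,
                     return_tensors="pt")

train_loader = DataLoader(MyDataset(), batch_size=4, collate_fn=collate_fn)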

As I said, this works, but it can’t be the intended way…


Any updates on this? Funnily enough, I'm in the exact same situation as you: trying to follow the same fine-tune-SAM tutorial, but I need to use control points instead of bounding boxes. I haven't found a single working tutorial online that trains the model using control points. For some reason I run into issues because of the point_batch_size, whatever that is… is this value related to the padding at all? The documentation is so vague…

The most elegant solution I could come up with is to override the _pad_points_and_labels function from the SamProcessor, pass just the batch of images to the SamImageProcessor, and do the point padding outside of the processor:

import numpy as np
import torch


def _pad_points_and_labels(input_points, input_labels):
    r"""
    The method pads the 2D points and labels to the maximum number of points in the batch.
    Adapted from processing_sam.py
    """
    point_pad_value = -10
    input_points = [np.array(point) for point in input_points]
    expected_nb_points = max([point.shape[0] for point in input_points])
    processed_input_points = []
    for i, point in enumerate(input_points):
        if point.shape[0] != expected_nb_points:
            # pad the labels with one pad value per missing point
            input_labels[i] = np.append(input_labels[i],
                                        [point_pad_value] * (expected_nb_points - point.shape[0]))
            # pad the points themselves with pad-value coordinates
            point = np.concatenate(
                [point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
            )
        input_labels[i] = torch.tensor(input_labels[i])
        processed_input_points.append(torch.tensor(point))
    input_points = processed_input_points

    # stack to a (batch, num_pts) label tensor and a (batch, 1, num_pts, 2) point tensor
    input_labels = torch.stack(input_labels, dim=0)
    input_points = torch.stack(input_points, dim=0).unsqueeze(1)
    return input_points, input_labels

I use this function in my collate function as follows:

from torch.utils.data import default_collate


def _custom_collate_fn(batch):
    # image_processor is a SamImageProcessor instance created elsewhere
    images = [sample.pop('cv_image') for sample in batch]
    points = [sample.pop('point_prompts') for sample in batch]
    input_labels = [sample.pop('point_labels') for sample in batch]

    # pass the images separately to SamImageProcessor
    inputs = image_processor(images, return_tensors="pt")

    # pad points and labels outside the image_processor
    inputs['input_points'], inputs['input_labels'] = _pad_points_and_labels(points, input_labels)

    out_dict = default_collate(batch)
    out_dict.update(inputs)
    out_dict['cv_image'] = images

    return out_dict
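
Roughly how this plugs into training (a sketch; the dataset is assumed to return dicts with the 'cv_image', 'point_prompts' and 'point_labels' keys used above):

from torch.utils.data import DataLoader
from transformers import SamImageProcessor, SamModel

image_processor = SamImageProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")

# dataset: any Dataset whose samples contain 'cv_image', 'point_prompts' and 'point_labels'
train_loader = DataLoader(dataset, batch_size=4, collate_fn=_custom_collate_fn)

for batch in train_loader:
    outputs = model(pixel_values=batch["pixel_values"],
                    input_points=batch["input_points"],
                    input_labels=batch["input_labels"],
                    multimask_output=False)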

Hope this helps.