As the title suggests, I want to fine-tune SAM in batches with varying numbers of points.

I couldn’t find an example of fine-tuning SAM that uses multiple points as prompts, only ones that use bounding boxes, like this example.

This is the gist of my code. input_points is currently just a list of lists of lists of xy coordinates (shape [batch, 1, num_pts, 2]).

It can’t be stacked and converted to a torch tensor because `num_pts` varies between samples.

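```
import random

import numpy as np
from torch.utils.data import Dataset
from transformers import SamModel, SamProcessor


class MyDataset(Dataset):
    def __init__(self):
        self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    def __getitem__(self, idx):
        # each sample gets a random number of point prompts (1 to 5)
        num_pts = random.randint(1, 5)
        point_prompts = [(random.random(), random.random()) for _ in range(num_pts)]
        # random dummy image
        image = np.random.rand(224, 224, 3) * 255
        image = image.astype(np.uint8)
        inputs = self.processor(image, input_points=[point_prompts])
        return inputs


def train():
    model = SamModel.from_pretrained("facebook/sam-vit-base")
    # train_loader is a DataLoader over MyDataset; pixel_values and prompts
    # come from the batch (definition and unpacking omitted in this gist)
    for batch in train_loader:
        outputs = model(pixel_values=pixel_values,
                        input_points=prompts,
                        multimask_output=False)
```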

I read in the documentation that the processor pads the points to make `num_pts` the same across all samples, but I don’t understand how this is supposed to work.

The points can be obtained by passing a list of list of list to the processor that will create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per input point), the third dimension is the number of points per segmentation mask (it is possible to pass multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) coordinates of the point. If a different number of points is passed either for each image, or for each mask, the processor will create “PAD” points that will correspond to the (0, 0) coordinate, and the computation of the embedding will be skipped for these points using the labels.
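
To make the quoted passage concrete, here is a minimal NumPy sketch of what that kind of padding amounts to. It is not the processor's actual code: `pad_points` and `PAD_VALUE` are hypothetical names, and the pad value of -10 and the label 1 for real points are assumptions (the passage above mentions a (0, 0) pad coordinate, while the code later in this thread uses -10).

```python
import numpy as np

PAD_VALUE = -10  # assumption: pad value borrowed from the code later in this thread


def pad_points(batch_points, pad_value=PAD_VALUE):
    """Pad a ragged list of per-image point lists to a common length.

    batch_points: one entry per image, each a list of (x, y) pairs.
    Returns points of shape (batch, 1, max_pts, 2) and labels of shape
    (batch, 1, max_pts). Padded slots hold pad_value in both arrays.
    """
    max_pts = max(len(pts) for pts in batch_points)
    points = np.full((len(batch_points), 1, max_pts, 2), pad_value, dtype=np.float32)
    labels = np.full((len(batch_points), 1, max_pts), pad_value, dtype=np.int64)
    for i, pts in enumerate(batch_points):
        if pts:
            points[i, 0, : len(pts)] = np.asarray(pts, dtype=np.float32)
            # assumption: every real point is a foreground click (label 1)
            labels[i, 0, : len(pts)] = 1
    return points, labels


# two images: one with a single point, one with three points
points, labels = pad_points([[(10, 20)], [(5, 5), (30, 40), (50, 60)]])
print(points.shape, labels.shape)
```

The idea is that the labels carry the information about which slots are padding, so the padded coordinates themselves can be ignored downstream.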


I was able to make a little progress with my problem.

It seems like you have to pass your images and input_points batch-wise to the SamProcessor and in addition input_labels (see here).

However, this only worked without errors for me after adjusting `_pad_points_and_labels` in the source code of `processing_sam.py` in the following way:

```
def _pad_points_and_labels(self, input_points, input_labels):
    r"""
    The method pads the 2D points and labels to the maximum number of points in the batch.
    """
    expected_nb_points = max([point.shape[0] for point in input_points])
    processed_input_points = []
    for i, point in enumerate(input_points):
        if point.shape[0] != expected_nb_points:
            # pad the labels first, while point.shape[0] is still the unpadded length
            input_labels[i] = np.append(input_labels[i],
                                        [self.point_pad_value] * (expected_nb_points - point.shape[0]))
            point = np.concatenate(
                [point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0
            )
        processed_input_points.append(point)
    input_points = processed_input_points
    return input_points, input_labels
```

I moved the processing part from the `__getitem__` method to the `collate_fn` in the data loader.

As I said, this works, but it can’t be the intended way…


Any updates on this? Funnily enough I’m in the exact same situation as you, trying to follow the same fine-tune SAM tutorial but I need to use control points instead of bounding boxes. Haven’t found a single working tutorial online training the model using control points. For some reason I have issues due to the `point_batch_size`, whatever that is… is this value related to the padding at all? The documentation is so vague…
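
On `point_batch_size`: going by the documentation passage quoted earlier in this thread, it is the second axis of `input_points`, i.e. how many masks are requested per image, while the padding fills the third axis (points per mask). A small sketch of the difference, using dummy arrays only; `describe` is a hypothetical helper, not a `transformers` function:

```python
import numpy as np

# One image, ONE mask described by three points:
# (batch=1, point_batch_size=1, num_pts=3, xy=2)
one_mask_three_points = np.zeros((1, 1, 3, 2))

# One image, THREE masks, each described by a single point:
# (batch=1, point_batch_size=3, num_pts=1, xy=2)
three_masks_one_point = np.zeros((1, 3, 1, 2))


def describe(input_points):
    """Spell out what each axis of a 4D input_points array means."""
    batch, point_batch_size, num_pts, _ = input_points.shape
    return (f"{batch} image(s), {point_batch_size} mask(s) per image, "
            f"{num_pts} point(s) per mask")


print(describe(one_mask_three_points))
print(describe(three_masks_one_point))
```

So if several clicks should all contribute to one mask, they presumably belong on the third axis with `point_batch_size` equal to 1, which is the layout `[batch, 1, num_pts, 2]` used at the top of this thread.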

The most elegant solution I could come up with is reimplementing the `_pad_points_and_labels` function from the `SamProcessor`, passing just the batch of images to the `SamImageProcessor`, and doing the point padding outside of the processor:

```
import numpy as np
import torch


def _pad_points_and_labels(input_points, input_labels):
    r"""
    The method pads the 2D points and labels to the maximum number of points in the batch.
    Adapted from processing_sam.py
    """
    point_pad_value = -10
    input_points = [np.array(point) for point in input_points]
    expected_nb_points = max([point.shape[0] for point in input_points])
    processed_input_points = []
    for i, point in enumerate(input_points):
        if point.shape[0] != expected_nb_points:
            # pad the labels first, while point.shape[0] is still the unpadded length
            input_labels[i] = np.append(input_labels[i],
                                        [point_pad_value] * (expected_nb_points - point.shape[0]))
            point = np.concatenate(
                [point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
            )
        input_labels[i] = torch.tensor(input_labels[i])
        processed_input_points.append(torch.tensor(point))
    input_points = processed_input_points
    # labels: [batch, num_pts]; points: [batch, 1, num_pts, 2]
    input_labels = torch.stack(input_labels, dim=0)
    input_points = torch.stack(input_points, dim=0).unsqueeze(1)
    return input_points, input_labels
```

I use this function in the collate function as follows:

```
from torch.utils.data.dataloader import default_collate


# image_processor is a SamImageProcessor instance created elsewhere
def _custom_collate_fn(batch):
    # pull the variable-length fields out before default collation
    images = [sample.pop('cv_image') for sample in batch]
    points = [sample.pop('point_prompts') for sample in batch]
    input_labels = [sample.pop('point_labels') for sample in batch]
    # pass the images separately to SamImageProcessor
    inputs = image_processor(images, return_tensors="pt")
    # pad points and labels outside the image_processor
    inputs['input_points'], inputs['input_labels'] = _pad_points_and_labels(points, input_labels)
    # collate the remaining fixed-size fields the usual way
    out_dict = default_collate(batch)
    out_dict.update(inputs)
    out_dict['cv_image'] = images
    return out_dict
```

Hope this helps.