What on earth is point_batch_size for the transformers SamModel?

Hello everyone. After scouring the internet and refactoring my code several times, I am at my wit’s end with an issue related to the SamModel from the transformers package, and I'm hoping somebody here can see what I am missing.

Background

I am trying to fine-tune SAM. There are several tutorials doing this very thing online; however, I have not been able to find a single one that uses control points, as all of them use bounding boxes as inputs. Somehow, this seemingly trivial issue has ground all productivity to a halt.

When prompting the SAM model at inference time, I was able to get predictions from control point inputs without issue by supplying the points and labels in this form, as per the documentation:

input_points: [[ 68 115], [ 79  45], [ 52 261], [ 44 169], [244 496], [115 495] ... ]
input_labels: [1, 1, 1, 1, 0, 0 ... ]
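
For context, the inference-time call that works for me looks roughly like the sketch below (trimmed down: the image path is a placeholder, and the points above are wrapped in one extra list for the single image, following the nested-list form the processor accepts):

import torch
from PIL import Image
from transformers import SamModel, SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")

image = Image.open("my_image.png").convert("RGB")  # placeholder path

# One image, with its control points (x, y) and matching labels (1 = foreground, 0 = background).
input_points = [[[68, 115], [79, 45], [52, 261], [44, 169], [244, 496], [115, 495]]]
input_labels = [[1, 1, 1, 1, 0, 0]]

inputs = processor(image, input_points=input_points, input_labels=input_labels, return_tensors="pt")
with torch.no_grad():
  outputs = model(**inputs, multimask_output=False)

# Resize the predicted masks back to the original image resolution.
masks = processor.image_processor.post_process_masks(
  outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)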

I assumed this would also be the form for providing input points and labels at training time, and the documentation (kind of, we will get to that) seemed to say the same. If I print out an element of my dataset class and a batched element, I get the following outputs, respectively:

# train_dataset[0].items()
pixel_values torch.Size([3, 1024, 1024])
original_sizes torch.Size([2])
reshaped_input_sizes torch.Size([2])
input_points torch.Size([5, 2])
input_labels torch.Size([5, 1])
ground_truth_mask (1024, 1024)

# batch.items()
pixel_values torch.Size([2, 3, 1024, 1024])
original_sizes torch.Size([2, 2])
reshaped_input_sizes torch.Size([2, 2])
input_points torch.Size([2, 5, 2])
input_labels torch.Size([2, 5, 1])
ground_truth_mask torch.Size([2, 1024, 1024])

Note that the batch size is 2, not that I can see how that would matter. The relevant shapes are those of input_points and input_labels, with each having the following shape in the batched item (the DataLoader producing this batch is sketched after the list):

  • input_points torch.Size([2, 5, 2]): batch size, number of points, dimensionality of points (2 since they represent x,y coords)
  • input_labels torch.Size([2, 5, 1]): batch size, number of labels (equal to number of points), dimensionality of labels (1 since valid labels are either 1, 0, -1 or -10)
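
For completeness, the batch above comes straight from a plain PyTorch DataLoader over my dataset class (shown further down) with the default collate function and a batch size of 2, roughly:

from torch.utils.data import DataLoader

# The default collate_fn stacks each per-item tensor along a new leading batch dimension,
# which is how input_points goes from (5, 2) per item to (2, 5, 2) per batch.
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)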

Issue

When I instantiate the SamModel and try to run a forward pass, I get the error below, complaining that input_points is not a 4D tensor.

ValueError: ('The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.', ' got torch.Size([2, 5, 2]).')

This is puzzling for two reasons. Firstly, the documentation simply doesn’t mention what point_batch_size is anywhere that I can see. Secondly, the documentation claims input_points should be a 3D tensor of shape (batch_size, num_points, 2), which is exactly what I have, while the error claims otherwise!

Regardless, I still need to figure out what point_batch_size is, because it is included in the documentation as the size of one of the dimensions of input_labels, which is (batch_size, point_batch_size, num_points).
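
If I read the error message literally, I could presumably make the shapes line up by adding a singleton dimension in position 1, as in the blind guess below, but I have no idea whether a point_batch_size of 1 is actually what the model expects here, which is exactly my question.

# Blind guess: treat each image as having a single "point batch" containing all 5 points,
# i.e. reshape (batch_size, nb_points, 2) -> (batch_size, 1, nb_points, 2)
# and (batch_size, nb_points, 1) -> (batch_size, 1, nb_points).
input_points = batch["input_points"].unsqueeze(1)              # torch.Size([2, 5, 2]) -> torch.Size([2, 1, 5, 2])
input_labels = batch["input_labels"].squeeze(-1).unsqueeze(1)  # torch.Size([2, 5, 1]) -> torch.Size([2, 1, 5])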

Does anybody know what on earth point_batch_size is supposed to be? And is there any way to submit corrections to the documentation? Once this is all figured out, I’d like to contribute to the documentation somehow so other people don’t lose hours on this silly issue too.

Code

For reference, in case it helps, here are some relevant parts of my source code. It is pretty directly adapted from this tutorial.

Dataset Source Code
import numpy as np
from torch.utils.data import Dataset

class SAMDataset(Dataset):
  def __init__(self, dataset, processor):
    self.dataset = dataset
    self.processor = processor

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    datapoint = self.dataset[idx]
    image = datapoint["image"]
    ground_truth_mask = np.array(datapoint["label"])

    # Get control point prompt (generateInputPointsFromMask is my own helper that samples 5 positive points from the ground-truth mask).
    input_points, input_labels = generateInputPointsFromMask(numPositive=5, mask=ground_truth_mask)
    input_points = input_points.astype(float).tolist()
    input_labels = input_labels.tolist()
    input_labels = [[x] for x in input_labels]
    
    # Prepare the image and prompt for the model.
    inputs = self.processor(image, input_points=input_points, input_labels=input_labels, return_tensors="pt")
    
    # Remove batch dimension which the processor adds by default.
    inputs = {k:v.squeeze(0) for k,v in inputs.items()}
    inputs["input_points"] = inputs["input_points"].squeeze(1)
    inputs["input_labels"] = inputs["input_labels"].squeeze(1)

    # Add ground truth segmentation.
    inputs["ground_truth_mask"] = ground_truth_mask

    return inputs
Model Instantiation Source Code
import torch
from transformers import SamModel, PreTrainedModel
from tqdm import tqdm
from statistics import mean

model = SamModel.from_pretrained("facebook/sam-vit-base")

num_epochs = 100
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.train()

for epoch in range(num_epochs):
  epoch_losses = []
  
  for batch in tqdm(train_dataloader):
    # Forward pass with the point prompts from the batch.
    outputs = model(
      pixel_values=batch["pixel_values"].to(device),
      input_points=batch["input_points"].to(device),
      input_labels=batch["input_labels"].to(device),
      multimask_output=False,
    )

    ...