Two questions about Segment Anything Model (SAM) in Transformers

Hi team! :wave:

I have two questions about the Segment Anything Model (SAM) available in the Transformers package.

  1. My project requires SAM to operate in two modes: generate all masks, and generate masks based on a points prompt. The original implementation allowed me to load SAM just once and then pass it to SamAutomaticMaskGenerator if I want all masks, or to SamPredictor if I want masks based on a prompt (rough sketch of that workflow below).

I don’t see how I can do it with Transformers. I can use:

pipeline("mask-generation", model="facebook/sam-vit-huge", device=0)

if I want all the masks or:

model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

if I want the model to be promptable.

But will that require loading the model twice: once for the pipeline and once for the SamModel + SamProcessor setup?

  2. I want to detect multiple masks on a single image. In the original SAM I can call predictor.set_image(image) to run the encoder and predictor.predict to prompt the model, which lets me run predictor.predict multiple times without the encoding overhead. Is that possible with Transformers?
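
Roughly what I do with the original segment-anything package (a quick sketch; the checkpoint filename, image path, and point coordinates below are just placeholders):

import numpy as np
from PIL import Image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# load the SAM weights once (placeholder checkpoint path)
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
image = np.asarray(Image.open("my_image.jpg").convert("RGB"))  # placeholder image

# mode 1: generate all masks
mask_generator = SamAutomaticMaskGenerator(sam)
all_masks = mask_generator.generate(image)

# mode 2: promptable masks; the encoder runs only once in set_image
predictor = SamPredictor(sam)
predictor.set_image(image)
masks, scores, logits = predictor.predict(
    point_coords=np.array([[450, 600]]),  # placeholder point
    point_labels=np.array([1]),
)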

I would be very grateful if you could help me with this.

Hi,

Regarding your second question:

You can use the get_image_embeddings method to obtain the image embeddings for your image first, after which you can pass these to the forward pass (rather than providing pixel_values). It looks like this:

from transformers import SamModel
import torch

model = SamModel.from_pretrained("facebook/sam-vit-base")

pixel_values = torch.randn(1, 3, 1024, 1024)

image_embeddings = model.get_image_embeddings(pixel_values)

with torch.no_grad():
    # replace ... with your prompts, e.g. input_points and input_labels
    outputs = model(image_embeddings=image_embeddings, ...)

@lysandre has an answer to your first question: all pipelines allow you to pass an already instantiated model:

from transformers import SamModel, pipeline, SamProcessor

model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

pipe = pipeline('mask-generation', model=model, image_processor=processor)

Hence you only need to load the model once.
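
With that, the same model instance can serve both modes, similar to passing one SAM instance to both SamAutomaticMaskGenerator and SamPredictor. A minimal sketch, assuming the model, processor and pipe objects from the snippet above (the image URL and point coordinates are just examples):

import torch
import requests
from PIL import Image

img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

# mode 1: generate all masks via the pipeline (reuses the already loaded model)
all_masks = pipe(raw_image, points_per_batch=64)

# mode 2: promptable segmentation via a direct forward pass with the same model
inputs = processor(raw_image, input_points=[[[450, 600]]], return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)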


@nielsr I tried to follow your instructions, but the answer to the second question is still not clear to me :confused: This is what I have:

import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamImageProcessor

model_name = "facebook/sam-vit-huge"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = SamModel.from_pretrained(model_name).to(device)
image_processor = SamImageProcessor.from_pretrained(model_name)

image = Image.open("...")
result = image_processor.preprocess(image)
pixel_values = np.array(result['pixel_values'])
pixel_values = torch.from_numpy(pixel_values).to(device)

image_embeddings = model.get_image_embeddings(pixel_values)

# this is me just trying to get my tensors into the right shapes; not sure if that's what I need to do...
input_points = torch.from_numpy(approximated_polygons[0][np.newaxis, np.newaxis, :]).to(device)  # (1, 1, 5, 2)
input_labels = torch.from_numpy(np.ones(5)[np.newaxis, np.newaxis, :]).to(device)  # (1, 1, 5)

with torch.no_grad():
    outputs = model(image_embeddings=image_embeddings, input_points=input_points, input_labels=input_labels)  # torch.Size([1, 1, 3, 256, 256])
  • Is there a better way to preprocess input_points and input_labels?
  • How do I post-process the outputs?

Hi,

First off, it’s recommended to just call the processor rather than calling image_processor.preprocess directly. I’ll copy the code snippet from the docs here, tweaked to reuse the image embeddings:

import torch
from PIL import Image
import requests
from transformers import SamModel, SamProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]  # 2D location of a window in the image

inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)

# first, get the image embeddings
image_embeddings = model.get_image_embeddings(inputs.pixel_values)

# next, one can run the forward with the image embeddings on various prompts
del inputs["pixel_values"]
with torch.no_grad():
    outputs = model(image_embeddings=image_embeddings, **inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores
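
To run further prompts on the same image without re-encoding it, you can reuse the cached image_embeddings and only re-run the processor for the new points. A sketch continuing the snippet above (the coordinates are arbitrary examples):

# a second prompt on the same image, reusing the cached image embeddings
new_points = [[[200, 300]]]  # arbitrary example coordinates
new_inputs = processor(raw_image, input_points=new_points, return_tensors="pt").to(device)
del new_inputs["pixel_values"]  # not needed, we pass image_embeddings instead

with torch.no_grad():
    new_outputs = model(image_embeddings=image_embeddings, **new_inputs)

new_masks = processor.image_processor.post_process_masks(
    new_outputs.pred_masks.cpu(), new_inputs["original_sizes"].cpu(), new_inputs["reshaped_input_sizes"].cpu()
)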

Hey @nielsr,
Thanks for sharing the code snippets!
I’ve spotted the get_image_embeddings() function there. I find this related to a recent chat about a “smoother” timm interface for the transformers library.

I understand that the SAM interface is a particular case, but I’d really like to get your input on the matter.

A fan