I am trying to fine-tune the Segment Anything (SAM) model following the recently-posted demo notebook (thanks, @nielsr and @ybelkada !).
I am trying to use 1024x1024 pixel images and masks, but when I try to calculate the loss with
loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1)), I get the following error:
AssertionError: ground truth has different shape (torch.Size([1, 1, 1024, 1024])) from input (torch.Size([1, 1, 256, 256]))
Obviously, somewhere in the model() call, the input 1024x1024 tensor is getting downsampled to 256x256. In the example notebook, the input image is already 256x256, so the mask is as well.
I am wondering how best to handle this. Should I simply downsample my masks before calculating the loss, or is there some parameter I can change so that I can use 1024x1024 masks?
Thanks for your interest in my notebook! You could use torch.nn.functional.interpolate to interpolate the predicted masks to the appropriate size before calculating the loss:
from torch import nn

# upsample the predicted masks from 256x256 to the 1024x1024 ground-truth size
predicted_masks = nn.functional.interpolate(predicted_masks, size=(1024, 1024), mode="bilinear", align_corners=False)
This is also used in the post_process_masks method of the image processor.
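For reference, here is a minimal sketch of that post-processing at inference time, assuming a SamProcessor called processor and the usual inputs/outputs of a SamModel forward pass:

import torch

with torch.no_grad():
    outputs = model(**inputs, multimask_output=False)

# resize the low-resolution (256x256) predicted masks back to the original image size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)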
I was experimenting with the same notebook. In my case, I have trained on a custom dataset. Now, I want it to predict masks automatically (without any prompts like bounding boxes). How can I do that?
I tried a few things (passing input_boxes = [0, 0, img_width, img_height], using pipeline("mask-generator", model=my_model, processor=my_processor) from here, and loading my checkpoint like in here), but they didn’t work.
To use the fine-tuned model with the pipeline, you can save the processor and the tuned model with save_pretrained to the same folder; after that, pass the path of that folder to the pipeline.
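A minimal sketch of that workflow, assuming the fine-tuned SamModel and SamProcessor from the notebook and a hypothetical output folder ./sam-finetuned:

from transformers import pipeline

# save the fine-tuned model and its processor into the same folder
my_model.save_pretrained("./sam-finetuned")
my_processor.save_pretrained("./sam-finetuned")

# point the mask-generation pipeline at that folder to get prompt-free masks
generator = pipeline("mask-generation", model="./sam-finetuned")
outputs = generator("path/to/image.png", points_per_batch=64)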
Hi, I’ve just opened an issue about this, as I had the same problem. My inputs and masks are not all the same size, so I couldn’t work out how to implement the solution you suggested and just resized my masks instead. I thought it would be good to comment here in case others find this thread.
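For anyone who prefers resizing the masks instead, here is a minimal sketch, assuming the ground-truth masks are binary tensors of shape (batch, height, width):

from torch import nn

# downsample ground-truth masks to SAM's 256x256 low-resolution output size
ground_truth_masks = nn.functional.interpolate(
    ground_truth_masks.unsqueeze(1).float(),  # add a channel dim -> (batch, 1, H, W)
    size=(256, 256),
    mode="nearest",  # nearest-neighbour keeps the masks binary
)
loss = seg_loss(predicted_masks, ground_truth_masks)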