I am trying to fine-tune the Segment Anything (SAM) model following the recently posted demo notebook (thanks, @nielsr and @ybelkada!).
I am trying to use 1024x1024 pixel images and masks, but when I try to calculate the loss with `loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))`, I get an error from `monai`:

```
AssertionError: ground truth has different shape (torch.Size([1, 1, 1024, 1024])) from input (torch.Size([1, 1, 256, 256]))
```
Obviously, somewhere in the `model()` call the 1024x1024 input is downsampled, so the predicted masks come out at 256x256. In the example notebook the input images are already 256x256, so the masks are as well.
I am wondering what the best way to handle this is: should I simply downsample my masks before calculating the loss, or is there some parameter I can change so that I can use 1024x1024 masks?
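For reference, here is a minimal sketch of the two workarounds I am considering. The `DiceCELoss` arguments are copied from the demo notebook, and the tensors are just dummies with the shapes from the error message above, so treat this as an illustration rather than my actual training loop:

```python
import torch
import torch.nn.functional as F
import monai

# Loss as set up in the demo notebook (assumed; adjust to your own config).
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction="mean")

# Dummy tensors matching the shapes reported in the error.
predicted_masks = torch.randn(1, 1, 256, 256)                     # low-res mask logits from the model
ground_truth_masks = torch.randint(0, 2, (1, 1024, 1024)).float() # full-resolution binary masks
gt = ground_truth_masks.unsqueeze(1)                              # -> (1, 1, 1024, 1024)

# Option A: upsample the predicted logits to the ground-truth resolution.
upsampled_pred = F.interpolate(
    predicted_masks, size=gt.shape[-2:], mode="bilinear", align_corners=False
)
loss_a = seg_loss(upsampled_pred, gt)

# Option B: downsample the ground-truth masks to the prediction resolution
# (nearest-neighbour interpolation keeps the masks binary).
downsampled_gt = F.interpolate(gt, size=predicted_masks.shape[-2:], mode="nearest")
loss_b = seg_loss(predicted_masks, downsampled_gt)
```

Is one of these the recommended approach, or is there a cleaner way built into the model?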