Hi,
Thanks for your interest in my notebook! You could use `torch.nn.functional.interpolate` to upsample the predicted masks to the appropriate size before calculating the loss:
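```python
from torch import nn

# predicted_masks: 4D tensor (batch_size, num_masks, height, width) of low-resolution logits
predicted_masks = nn.functional.interpolate(predicted_masks,
                                            size=(1024, 1024),
                                            mode='bilinear',
                                            align_corners=False)
```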
This is also what the postprocessing method of `SamImageProcessor` does.
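Here is a minimal, self-contained sketch of the whole step. It assumes 256x256 low-resolution logits, and it uses random tensors as stand-ins for the real model outputs and ground-truth masks. `BCEWithLogitsLoss` is only an example loss here, not necessarily the one the notebook uses. The sketch also shows that gradients flow back through the interpolation to the low-resolution predictions:

```python
import torch
from torch import nn

# Hypothetical stand-ins: 2 images, 1 mask each, 256x256 low-resolution logits
batch_size = 2
predicted_masks = torch.randn(batch_size, 1, 256, 256, requires_grad=True)
# Hypothetical binary ground-truth masks at 1024x1024
ground_truth_masks = (torch.rand(batch_size, 1, 1024, 1024) > 0.5).float()

# Upsample the logits to the ground-truth resolution before computing the loss
upsampled_masks = nn.functional.interpolate(predicted_masks,
                                            size=tuple(ground_truth_masks.shape[-2:]),
                                            mode='bilinear',
                                            align_corners=False)

# Example loss on logits (swap in whatever loss you actually train with)
loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(upsampled_masks, ground_truth_masks)
loss.backward()

print(upsampled_masks.shape)  # torch.Size([2, 1, 1024, 1024])
```

Because bilinear interpolation is differentiable, the gradient ends up on the original 256x256 predictions, so the model is trained as usual.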