Inference with SegFormer

I fine-tuned nvidia/mit-b5 for semantic segmentation. Training went well, but during inference the model predicts the whole picture as a single mask. I can't figure out what I did wrong.
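For context, the logits come from a forward pass that looks roughly like this (simplified sketch; the checkpoint path and image file name are placeholders for my actual ones):

import numpy as np
import torch
from torch import nn
from PIL import Image
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

# "path/to/my-finetuned-mit-b5" and "example.jpg" are placeholders
processor = SegformerImageProcessor.from_pretrained("path/to/my-finetuned-mit-b5")
model = SegformerForSemanticSegmentation.from_pretrained("path/to/my-finetuned-mit-b5")
model.eval()

image = np.array(Image.open("example.jpg").convert("RGB"))  # (H, W, 3)
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits  # (1, num_labels, H/4, W/4)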

First, rescale the logits to the original image size:

upsampled_logits = nn.functional.interpolate(
    logits,
    size=image.shape[:-1],  # (height, width) of the original image
    mode='bilinear',
    align_corners=False,
)

Second, apply argmax along the class dimension:

seg = upsampled_logits.argmax(dim=1)[0]  # (H, W) map of predicted class indices
seg = seg.cpu().numpy()
print(np.unique(seg, return_counts=True))

There is only one value, 1, for the whole image, i.e. every pixel is assigned to class 1.
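In case it helps with diagnosing this, here is a quick sanity check I can run on the raw logits before upsampling (just a diagnostic sketch using the variables from the snippets above): if one channel already dominates everywhere, the collapse happens at the model output rather than in the post-processing.

# Inspect the raw logits before upsampling
print(logits.shape)                       # expected (1, num_labels, H/4, W/4)
print(logits[0].mean(dim=(1, 2)))         # mean logit per class channel
print(logits.argmax(dim=1).unique(return_counts=True))  # class counts at 1/4 resolution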

Can someone please help me? Thanks