Binary semantic segmentation using SegFormer

Hello, I’m trying to perform binary segmentation using SegFormer on a custom dataset, where the mask (a grayscale PNG image) has the value 1 for the object of interest and 0 for everything else. I followed this tutorial and therefore applied reduce_labels=True, which leaves me with label 0 for the object and 255 for the background. However, when I pass the labels to the model, I get this error:
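For context, a minimal sketch of what the reduce_labels step does (following the logic used by the transformers image processors; the exact sentinel values here are an assumption based on that behaviour): background label 0 is mapped to the ignore index 255, and every other label is shifted down by one.

```python
import numpy as np

# Toy binary mask as described above: 1 = object, 0 = background.
mask = np.array([[0, 1],
                 [1, 0]], dtype=np.int64)

# Sketch of reduce_labels: 0 -> 255 (ignore), everything else shifted down by one.
reduced = mask.copy()
reduced[reduced == 0] = 255   # background becomes the ignore index
reduced = reduced - 1         # shift remaining labels down
reduced[reduced == 254] = 255 # keep the ignore index at 255 after the shift

# reduced now holds 0 for the object and 255 for the background,
# which is exactly the situation described in the post.
```

This shows why a single-object mask ends up with only one "real" label (0) plus the ignore index, triggering the error.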

ValueError: The number of labels should be greater than one

How should I fix this, please? Should I treat the region outside my object not as background, but rather as another label?

Hi,

The current implementation of SegFormer actually only supports multi-class classification per pixel, using the regular CrossEntropyLoss.

For binary classification per pixel, one should instead use BCEWithLogitsLoss (which applies a sigmoid activation to the logits and computes the binary cross-entropy loss).
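A minimal sketch of that loss, independent of SegFormer itself (the shapes below are illustrative, not what the model emits):

```python
import torch
import torch.nn as nn

# BCEWithLogitsLoss fuses a sigmoid with binary cross-entropy,
# which is numerically more stable than applying them separately.
criterion = nn.BCEWithLogitsLoss()

# Fake per-pixel logits: batch of 2, one channel, 4x4 map.
logits = torch.randn(2, 1, 4, 4)
# Binary ground-truth mask as floats: 1.0 = object, 0.0 = background.
labels = torch.randint(0, 2, (2, 1, 4, 4)).float()

loss = criterion(logits, labels)  # scalar loss over all pixels
```

Note that the targets must be floats of the same shape as the logits, unlike CrossEntropyLoss, which expects integer class indices.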

I’ll add support for this! Thanks for bringing this up.

Hello,

Thank you for your answer. I’m still wondering whether this couldn’t be overcome by pretending there are two classes (object and not-object), hence two labels and no background. Or something like a virtual background at 255 which simply won’t be present in the training data.

Thanks.

Yes, that’s totally possible; in that case you can just use the regular cross-entropy loss. The segmentation maps should then contain the values 0 and 1 (which the model learns to predict for each pixel), and you can use 255 for pixels you want to ignore (i.e. pixels the model shouldn’t learn from). So you can set config.num_labels to 2 in that case.
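The ignore-index mechanism described above can be sketched with plain PyTorch (this is the standard CrossEntropyLoss behaviour, not SegFormer-specific code; the shapes are illustrative):

```python
import torch
import torch.nn as nn

# Two classes (0 = background, 1 = object); 255 marks pixels to skip,
# matching what config.num_labels = 2 with ignore index 255 would do.
criterion = nn.CrossEntropyLoss(ignore_index=255)

logits = torch.randn(2, 2, 4, 4)          # (batch, num_labels, H, W)
labels = torch.randint(0, 2, (2, 4, 4))   # per-pixel class indices in {0, 1}
labels[:, 0, 0] = 255                     # these pixels contribute no gradient

loss = criterion(logits, labels)
```

Pixels labelled 255 are simply excluded from the loss average, so the model is never penalised for its predictions there.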

Hello,

Thank you very much. Can you please reply when you add support for BCEWithLogitsLoss?

Thanks

Hi,

@Ivor22, I’ve opened a PR here: [SegFormer] Add support for segmentation masks with one label by NielsRogge · Pull Request #20279 · huggingface/transformers · GitHub.

If it works fine for you, I’ll add it to all other models.

As this seems to work now, should I use num_labels = 1 in this case?
I think the documentation still only refers to and explains the num_labels > 1 case?

And apologies, judging from the GitHub pull request, is this now automatically using the Dice loss?