Modifying ViT to accept a 4th input channel

Hello!

I have a set of RGB images, and I have added a fourth channel containing a segmentation output that could help the model with the classification task I am working on. I am starting from the code in the image classification example on GitHub (link), but I notice that it only takes 3-channel images. How would I begin modifying the code to accept 4-channel images?
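If the segmentation output is saved as a separate single-channel mask rather than inside the image file, one way to build the 4-channel image is to merge the mask with the RGB bands using Pillow. This is a minimal sketch: the helper names `stack_rgb_and_mask` and `load_rgb_plus_mask`, and the choice of storing the mask in the "A" band, are assumptions for illustration, not part of the image classification example.

```python
from PIL import Image


def stack_rgb_and_mask(rgb: Image.Image, mask: Image.Image) -> Image.Image:
    """Hypothetical helper: merge an RGB image with a single-channel
    segmentation mask into one 4-band image (the mask goes in the "A" band)."""
    rgb = rgb.convert("RGB")
    mask = mask.convert("L")  # 8-bit, single channel
    if mask.size != rgb.size:
        # nearest-neighbour resizing keeps discrete label values intact
        mask = mask.resize(rgb.size, Image.NEAREST)
    r, g, b = rgb.split()
    return Image.merge("RGBA", (r, g, b, mask))


def load_rgb_plus_mask(rgb_path: str, mask_path: str) -> Image.Image:
    # Hypothetical loader: the two paths stand in for however the dataset pairs files
    with Image.open(rgb_path) as rgb, Image.open(mask_path) as mask:
        return stack_rgb_and_mask(rgb, mask)
```

Nearest-neighbour resampling is used for the mask so that class labels are never averaged into values that do not correspond to any class.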

I appreciate any advice or help!

Hi,
did you find a solution?

I think the image-loading part only needs a one-character change (converting to "RGBA" instead of "RGB"), but I don't know whether the transformers library handles 4-channel input correctly…

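from PIL import Image

def pil_loader(path: str) -> Image.Image:
    # Open in binary mode, mirroring torchvision's default loader
    with open(path, "rb") as f:
        im = Image.open(f)
        # For a plain RGB file the alpha band is filled with 255, so the
        # segmentation output must already be stored in the file's alpha band
        return im.convert("RGBA")  # 4 channels
        # return im.convert("RGB")  # original: 3 channels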
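Loading four bands is only half of the change: ViT splits the image into patches with a convolution whose input-channel count is 3 in pretrained checkpoints, so 4-channel tensors will fail there. One common workaround, sketched here in plain PyTorch, is to widen that convolution, copy the pretrained RGB filters, and initialise the extra channel's filters with their mean. The helper name `widen_patch_embedding` and the mean initialisation are assumptions, not something transformers provides.

```python
import torch
import torch.nn as nn


def widen_patch_embedding(conv: nn.Conv2d, new_in_channels: int = 4) -> nn.Conv2d:
    """Hypothetical helper: return a copy of a patch-embedding Conv2d that
    accepts more input channels.

    Pretrained RGB filters are copied unchanged; each extra channel gets the
    mean of the RGB filters (a common heuristic, assumed here).
    """
    old_in = conv.in_channels
    if new_in_channels <= old_in or conv.groups != 1:
        raise ValueError("only widening an ungrouped convolution is supported")
    new_conv = nn.Conv2d(
        new_in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None,
    ).to(device=conv.weight.device, dtype=conv.weight.dtype)
    with torch.no_grad():
        # weight shape: (out_channels, in_channels, kernel_h, kernel_w)
        new_conv.weight[:, :old_in] = conv.weight
        mean_filter = conv.weight.mean(dim=1, keepdim=True)
        new_conv.weight[:, old_in:] = mean_filter.expand(-1, new_in_channels - old_in, -1, -1)
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv
```

Because the RGB filters are copied exactly, an input whose fourth channel is all zeros produces the same patch embeddings as the original model; initialising the extra filters to zero instead is an equally reasonable choice. Assuming Hugging Face's `ViTForImageClassification`, the layer to replace appears to be `model.vit.embeddings.patch_embeddings.projection`, and `model.config.num_channels` plus the embedding module's own `num_channels` attribute (which its forward pass compares against the input) would also need to be set to 4. Treat those attribute paths as assumptions to check against your installed version; the image processor's `image_mean` and `image_std` would likewise need a fourth value.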