How do you use segmentation image processor with more than 3 channel images?

I’m trying to use Mask2Former with images with 6 channels. I thought this would work but the ImageProcessor throws an exception “ValueError: Unable to infer channel dimension format”

import numpy as np
import torch
from transformers import Mask2FormerConfig, Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor

# NUM_CHANNELS = 3 # Works!
NUM_CHANNELS = 6 # Does not work!

NUM_LABELS = 10
BACKBONE = "facebook/mask2former-swin-tiny-cityscapes-semantic"

model = Mask2FormerForUniversalSegmentation.from_pretrained(BACKBONE)
# Change to NUM_CHANNELS
model_config = model.config
model_config.backbone_config.num_channels = NUM_CHANNELS
model = Mask2FormerForUniversalSegmentation(model_config)

image_processor = Mask2FormerImageProcessor.from_pretrained(BACKBONE)

# Create a fake image and mask
image = np.random.randint(0, 256, (500, 500, NUM_CHANNELS)).astype('float32')
mask = np.random.randint(0, NUM_LABELS, (500, 500), dtype='uint8')
print(image.shape, mask.shape)

# Exception here ...
inputs = image_processor(image, segmentation_maps=mask, task_inputs=["semantic"], return_tensors="pt", input_data_format="channels_last")

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.loss)