Fine-Tuning with Grayscale Custom Dataset

Hi

This is actually my first time working with Hugging Face. I am trying to fine-tune a pre-trained model on my custom grayscale dataset of chest X-rays.

I have loaded and preprocessed the data and everything works fine. However, when I start training I get the following dimension error:
RuntimeError: stack expects each tensor to be equal size, but got [3, 224, 303] at entry 0 and [3, 224, 262] at entry 1
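For context, the feature_extractor and model referenced below are loaded from a pre-trained checkpoint, roughly like this (a minimal sketch; the exact checkpoint name and number of labels are just placeholders, not necessarily what I use):

from transformers import ViTFeatureExtractor, ViTForImageClassification

model_name = "google/vit-base-patch16-224-in21k"  # example checkpoint (assumption)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name, num_labels=2)  # assumption: binary labels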

I apply the following transformation to the data using torchvision:

from torchvision import transforms

class XRayTransform:
    """
    Transforms for pre-processing XRay data across a batch.
    """
    def __init__(self):
        self.transforms = transforms.Compose([
            transforms.Lambda(lambda pil_img: pil_img.convert("RGB")),
            transforms.Resize(feature_extractor.size),
            transforms.ToTensor(),
            transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
        ])

    def __call__(self, example_batch):
        example_batch["pixel_values"] = [self.transforms(pil_img) for pil_img in example_batch["image"]]
        return example_batch

I am applying it with set_transform.
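A minimal sketch of how it is attached (assuming ds is my DatasetDict with "train" and "validation" splits):

xray_transform = XRayTransform()
prepared_ds = ds  # set_transform works in place, so prepared_ds is the same DatasetDict
prepared_ds.set_transform(xray_transform)  # XRayTransform runs lazily, per batch, when examples are accessed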

Previously I was running the following functions, taken from the HF docs:

def process_example(example):
    inputs = feature_extractor(example['image'], return_tensors='pt')
    inputs['labels'] = example['labels']
    return inputs

def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['label'] = example_batch['label']
    return inputs

But I believe they are not suitable for grayscale data. I don't think there is a problem with my Trainer setup:

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=feature_extractor,
)
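collate_fn is the standard one from the tutorial (a sketch; it assumes each example exposes "pixel_values" and "labels"), and the torch.stack call inside it is exactly where the size mismatch is raised:

import torch

def collate_fn(batch):
    # torch.stack requires every pixel_values tensor to have the same shape;
    # with differently sized images this raises the RuntimeError shown above.
    return {
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.tensor([x["labels"] for x in batch]),
    }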

Hi, the issue is in this line of your transform:

transforms.Resize(feature_extractor.size)

If you only provide an integer to torchvision.transforms.Resize, it only resizes the shorter edge of the image to match that number, so the other dimension still varies from image to image. To make sure all images are square (and therefore the same size), you need to pass a tuple, e.g. transforms.Resize((224, 224)).
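In your XRayTransform that would look roughly like this (a sketch; it assumes feature_extractor.size is a plain integer such as 224):

self.transforms = transforms.Compose([
    transforms.Lambda(lambda pil_img: pil_img.convert("RGB")),
    # A (height, width) tuple forces every image to the same square size,
    # so the tensors stacked in the collator all match.
    transforms.Resize((feature_extractor.size, feature_extractor.size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
])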

Works like a charm.

Thanks, Nieslr!