@harpergrieve Here's my collate_fn(); if I recall correctly, I didn't change any of the functionality there:
def collate_fn(batch, image_processor):
    """Collate a list of processed examples into one padded model batch.

    Args:
        batch: Sequence of examples, each a mapping with a "pixel_values"
            entry and a "labels" entry.
        image_processor: Object exposing ``pad(pixel_values, return_tensors=...)``
            that returns a mapping containing "pixel_values" and, for some
            backends, "pixel_mask".

    Returns:
        A dict with "pixel_values", "labels" (one plain dict per example)
        and, when the processor produced one, "pixel_mask".
    """
    pixel_values = [item["pixel_values"] for item in batch]
    encoding = image_processor.pad(pixel_values, return_tensors="pt")
    # Copy each label mapping into a plain dict so the collated batch does not
    # share the per-example label containers.
    labels = [dict(item["labels"]) for item in batch]

    collated = {"pixel_values": encoding["pixel_values"], "labels": labels}
    # Some backends (e.g. YOLOS) do not return a pixel mask from pad(); only
    # forward it when present instead of requiring a hand-edited variant.
    if "pixel_mask" in encoding:
        collated["pixel_mask"] = encoding["pixel_mask"]
    return collated
My train/test datasets are created with:
def _apply_train_transform(examples):
    """Run the training augmentation and image processing on a batch."""
    return transform_aug_ann(examples, image_processor, train_transform)


def _apply_test_transform(examples):
    """Run the evaluation transform and image processing on a batch."""
    return transform_aug_ann(examples, image_processor, test_transform)


# The transforms are attached lazily: they run when examples are accessed.
train_ds = dataset["train"].with_transform(_apply_train_transform)
test_ds = dataset["test"].with_transform(_apply_test_transform)
where transform_aug_ann is:
def transform_aug_ann(examples, image_processor, transform):
    """Augment a batch of examples and encode it with the image processor.

    Args:
        examples: Batch mapping with "image_id", "image" and "objects"
            columns; each objects entry provides "bbox" and "category".
        image_processor: Callable invoked as
            ``image_processor(images=..., annotations=..., return_tensors="pt")``.
        transform: Callable invoked as
            ``transform(image=..., bboxes=..., category=...)`` that returns a
            mapping with "image", "bboxes" and "category".

    Returns:
        Whatever ``image_processor`` returns for the augmented images and
        their annotations.
    """
    image_ids = examples["image_id"]
    images, bboxes, areas, categories = [], [], [], []
    for image, objects in zip(examples["image"], examples["objects"]):
        # NOTE(review): the channel reversal hands BGR arrays to the image
        # processor; kept as in the original -- confirm this is intended and
        # matches what is done at inference time.
        image = np.array(image.convert("RGB"))[:, :, ::-1]
        out = transform(image=image, bboxes=objects["bbox"], category=objects["category"])
        images.append(out["image"])
        bboxes.append(out["bboxes"])
        categories.append(out["category"])
        # The transform may drop or rescale boxes, so the dataset's original
        # per-object areas can no longer be paired with the surviving boxes.
        # Recompute each area from its transformed COCO box (x, y, w, h) so
        # areas stay aligned with bboxes/categories and match the new scale.
        areas.append([float(box[2]) * float(box[3]) for box in out["bboxes"]])

    targets = [
        {"image_id": id_, "annotations": format_annotations(id_, cat_, ar_, box_)}
        for id_, cat_, ar_, box_ in zip(image_ids, categories, areas, bboxes)
    ]
    return image_processor(images=images, annotations=targets, return_tensors="pt")
and
def create_transform(width, height):
    """Build the albumentations pipelines used for training and evaluation.

    Args:
        width: Target image width for the evaluation resize.
        height: Target image height for the evaluation resize.

    Returns:
        A ``(train_transform, test_transform)`` tuple. Both pipelines take
        COCO-format boxes with labels supplied through the "category" field.
    """
    train_transform = albumentations.Compose(
        [
            albumentations.SmallestMaxSize(max_size=1000, p=1),
            albumentations.RandomScale(p=0.5, scale_limit=(-0.4, 0.8)),
        ],
        # Boxes that become too small or too occluded after augmentation are
        # dropped by these thresholds.
        bbox_params=albumentations.BboxParams(
            format="coco", label_fields=["category"], min_area=400, min_visibility=0.7
        ),
    )
    test_transform = albumentations.Compose(
        [
            # Resize takes (height, width) positionally; pass by keyword so the
            # two dimensions cannot be swapped for non-square targets.
            albumentations.Resize(height=height, width=width),
        ],
        bbox_params=albumentations.BboxParams(format="coco", label_fields=["category"]),
    )
    return train_transform, test_transform
Unfortunately I haven't run this against the CPPE5 (or equivalent) dataset; I ended up creating a custom one for a personal project, but it mimics the structure of CPPE5, and my data loader is just load_dataset(str(module_directory / "my_dataset"), name=config). If you can share your project, I can see if I can find the source of the problem.