import os
import torch
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
import matplotlib.pyplot as plt
class CropDiseaseDataset(Dataset):
def init(self, df, img_dir, transforms=None):
self.df = df.reset_index(drop=True) # Reset index to handle indexing correctly
self.img_dir = img_dir
self.transforms = transforms # Augmentations should be passed during dataset instantiation
self.file_names = self.df[‘Image_ID’].values # Assuming ‘file_names’ column for image paths
self.targets = self.df[‘class’].values # Assuming ‘class’ column for labels
self.bboxes = self.df[[‘xmin’, ‘ymin’, ‘xmax’, ‘ymax’]].values # Assuming columns for bounding boxes
def __len__(self):
return len(self.df)
def __getitem__(self, index):
# Construct image path
img_path = os.path.join(self.img_dir, self.file_names[index])
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert BGR (OpenCV) to RGB format
target = self.targets[index]
bbox = self.bboxes[index].astype(float) # Convert bbox to float
# Check if bbox is valid
if bbox.shape[0] != 4 or any(bbox < 0):
print(f"Warning: Invalid bbox format at index {index}. Skipping sample.")
return None
# Apply transformations if they exist
if self.transforms is not None:
transformed = self.transforms(image=img, bboxes=[bbox], labels=[target])
# If transformation removes bounding box or any issue, skip the sample
if len(transformed['bboxes']) == 0:
print(f"Warning: Transformation removed bbox at index {index}. Skipping sample.")
return None
img = transformed['image']
bbox = transformed['bboxes'][0] # Extract transformed bbox as single bbox
target = transformed['labels'][0]
return {
'image': img,
'bbox': bbox,
'label': target
}
Define the transformations
train_transform = A.Compose([
A.RandomResizedCrop(224, 224),
A.HorizontalFlip(p=0.5),
A.RandomGamma(gamma_limit=(80, 120), p=0.5),
A.RandomBrightnessContrast(p=0.5),
A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.5),
A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2(),
], bbox_params=A.BboxParams(format=‘pascal_voc’, min_visibility=0.05, label_fields=[‘labels’]))
val_transform = A.Compose([
A.Resize(256, 256),
A.CenterCrop(224, 224),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2(),
], bbox_params=A.BboxParams(format=‘pascal_voc’, min_visibility=0.05, label_fields=[‘labels’]))
def visualize_dataset_samples(dataset, num_samples=5):
“”"
Visualize a set number of samples from the dataset to check for issues.
“”"
for idx in range(num_samples):
print(f"\nVisualizing sample {idx}:“)
try:
sample = dataset[idx]
if sample is None:
print(f"Sample {idx} is None. Skipping.”)
continue
img = sample['image'].permute(1, 2, 0).numpy() # Convert to HWC format for display
bbox = sample['bbox']
label = sample['label']
# Draw bounding box on the image
x_min, y_min, x_max, y_max = map(int, bbox)
img = cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 2) # Draw bounding box
plt.imshow(img)
plt.title(f"Label: {label}")
plt.axis('off')
plt.show()
except IndexError as e:
print(f"Error visualizing sample {idx}: {e}")
except Exception as e:
print(f"An unexpected error occurred for sample {idx}: {e}")
Usage example (assuming your DataFrame df
and image directory IMAGE_DIR
are set up)
dataset = CropDiseaseDataset(df=df, img_dir=IMAGE_DIR, transforms=train_transform)
visualize_dataset_samples(dataset, num_samples=5)
The error: Visualizing sample 1:
Warning: Transformation removed bbox at index 1. Skipping sample.
Sample 1 is None. Skipping.
it reads some images and also show this error in some too?