I have been attempting to run MAE pretraining on a new Hiera model with my own dataset of 3D grayscale images of size 6x256x256 pixels. However, I am still fairly new to PyTorch and have been running into issues.
I initially wrote my own training script, relying only on the code in the official Hiera GitHub repository. However, the loss decreased over the first few iterations and then did not change for a dozen or so epochs. Figuring the problem lay in my code, I switched to the HieraForPreTraining class from the transformers library, which I modified to support 3D images using much of the code from the Hiera GitHub repository. However, I encountered exactly the same issue: the loss decreased over a handful of iterations from a little over 1 and then got stuck right at 0.9951. I have also noticed that if training is left to run, even while the loss appears to be stuck, the gradient norm decreases until it eventually starts being reported as NaN.
I have attached my short training script, which uses the transformers library, as it shows the parameters and image transformations I am using; I tried to stick to the training parameters given in the Hiera paper for the video models. I also show the Hiera configuration JSON file. The main difference is that my batch size has to be much smaller than what the paper recommends due to resource constraints, although I attempt to compensate for this with gradient accumulation.
{
"embed_dim": 144,
"image_size": [6, 256, 256],
"patch_size": [3, 7, 7],
"patch_stride": [2, 4, 4],
"patch_padding": [1, 3, 3],
"num_heads": [2, 4, 8, 16],
"num_query_pool": 2,
"query_stride": [1, 2, 2],
"masked_unit_size": [1, 8, 8],
"drop_path_rate": 0.2,
"num_channels": 1,
"decoder_hidden_size": 512,
"decoder_depth": 8,
"decoder_num_heads": 16,
"mask_ratio": 0.75,
"out_indices": [0, 1, 2, 3]
}
class ThreeDimColorJitter(torch.nn.Module):
    """Apply one shared brightness/contrast jitter to every slice of a volume.

    A single pair of factors is sampled per call and applied to each depth
    slice in turn, brightness first and then contrast.

    Args:
        brightness: Brightness jitter setting, forwarded unchanged to
            ``v2.ColorJitter``.
        contrast: Contrast jitter setting, forwarded unchanged to
            ``v2.ColorJitter``.
    """

    def __init__(self, brightness, contrast):
        super().__init__()
        self.brightness = brightness
        self.contrast = contrast

    def forward(self, image):
        """Return a jittered copy of ``image``.

        Defined as ``forward`` rather than ``__call__`` so that
        ``nn.Module.__call__`` (and any registered hooks) keeps working;
        calling the instance directly behaves as before.

        Args:
            image: Tensor indexed as ``(1, depth, height, width)``.

        Returns:
            A tensor with the same shape, dtype and device as ``image``.

        Raises:
            ValueError: If ``image`` is not 4-D with a leading dimension of 1.
                Only index 0 of the first dimension is processed, so any
                further channels would otherwise be dropped silently.
        """
        if image.ndim != 4 or image.shape[0] != 1:
            raise ValueError(
                "ThreeDimColorJitter expects a (1, D, H, W) tensor, got shape "
                f"{tuple(image.shape)}"
            )

        # The jitter is rebuilt on every call so that later changes to
        # self.brightness / self.contrast are picked up.
        color_jitter = v2.ColorJitter(brightness=self.brightness, contrast=self.contrast)
        _, brightness_factor, contrast_factor, _, _ = color_jitter.get_params(
            color_jitter.brightness, color_jitter.contrast,
            saturation=None, hue=None,
        )

        jittered_slices = []
        for depth_slice in image[0]:
            # The functional ops are given a (1, H, W) single-channel image.
            slice_2d = depth_slice.unsqueeze(0)
            # NOTE(review): torchvision appears to return None for a factor
            # whose jitter range is disabled (e.g. brightness=0); skip the
            # adjustment in that case instead of passing None on - confirm.
            if brightness_factor is not None:
                slice_2d = F.adjust_brightness(slice_2d, brightness_factor)
            if contrast_factor is not None:
                slice_2d = F.adjust_contrast(slice_2d, contrast_factor)
            jittered_slices.append(slice_2d)

        if not jittered_slices:
            # Zero-depth volume: nothing to jitter, keep shape/dtype/device.
            return image.clone()

        # Stacking (1, H, W) slices along dim 1 rebuilds (1, D, H, W) on the
        # input's own device, unlike a buffer from torch.empty().
        return torch.stack(jittered_slices, dim=1)
class CustomImageDataset(Dataset):
    """Dataset that reads one TIFF file per item and returns it as a tensor.

    Args:
        image_paths: Sequence of file paths, one per sample.
        transforms: Optional callable applied to each loaded tensor.
    """

    def __init__(self, image_paths, transforms=None):
        self.transforms = transforms
        self.image_paths = image_paths  # one file path per sample

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``, add a leading dimension and apply transforms."""
        # NOTE(review): the original comment gave the loaded shape as
        # (6, 512, 512), while the config's image_size is [6, 256, 256] -
        # confirm which one is correct.
        volume = torch.from_numpy(tifffile.imread(self.image_paths[idx]))
        volume = volume.unsqueeze(0)  # prepend a size-1 (channel) dimension
        if not self.transforms:
            return volume
        return self.transforms(volume)
def collate_fn(examples):
    """Stack per-sample tensors into a batch keyed as ``pixel_values``.

    Args:
        examples: Iterable of equally shaped tensors, one per sample.

    Returns:
        A dict with the single key ``"pixel_values"`` whose value is the
        samples stacked along a new leading dimension.
    """
    batch = torch.stack(list(examples), dim=0)
    return {"pixel_values": batch}
def get_image_paths(directories, extensions=('.tiff',)):
    """Collect image file paths from several directories.

    Directories are visited in the order given; within each directory the
    entries are taken in sorted filename order.

    Args:
        directories: Iterable of directory paths to scan (non-recursive).
        extensions: Filename suffix, or iterable of suffixes, to accept.
            Matching is case-sensitive. Defaults to ``('.tiff',)``.

    Returns:
        List of ``os.path.join(directory, filename)`` for every matching entry.

    Raises:
        OSError: Propagated from ``os.listdir`` if a directory cannot be read.
    """
    # A bare string would otherwise be split into single characters by tuple().
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(extensions)

    image_paths = []
    for data_dir in directories:
        image_paths.extend(
            os.path.join(data_dir, filename)
            for filename in sorted(os.listdir(data_dir))
            if filename.endswith(suffixes)
        )
    return image_paths
def main():
    """Pretrain a ``HieraForPreTraining`` model on the configured image folders.

    Collects the image paths, splits them into train/eval subsets, builds the
    model from ``hiera_config.json`` and runs the Hugging Face ``Trainer``.

    Raises:
        FileNotFoundError: If no image files are found in ``data_dirs``.
    """
    # Paths and hyperparameters
    data_dirs = [
        './images_1',
        './images_2',
        './images_3',
        './images_4',
    ]
    train_eval_ratio = 0.8
    split_seed = 42  # fixes the train/eval split across runs and processes
    batch_size = 8  # Adjust based on your GPU memory
    epochs = 800  # Adjust based on training results
    learning_rate = 1.6e-3

    image_paths = get_image_paths(data_dirs)
    if not image_paths:
        # Fail early with a clear message instead of deep inside the Trainer.
        raise FileNotFoundError(f"No image files found in {data_dirs}")

    # Data transformations
    train_transforms = v2.Compose([
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomVerticalFlip(p=0.5),
        ThreeDimColorJitter(brightness=0.2, contrast=0.2),
        v2.ToDtype(torch.float32, scale=True),  # Converts to float and scales values to be between [0, 1]
    ])
    # Evaluation data gets only the deterministic dtype conversion, so that
    # eval results are comparable between evaluations.
    eval_transforms = v2.Compose([
        v2.ToDtype(torch.float32, scale=True),
    ])

    # Create Datasets
    # A dedicated seeded RNG makes the split identical on every run and on
    # every process of a distributed job; an unseeded shuffle would give each
    # process its own ordering and therefore a different train/eval split.
    random.Random(split_seed).shuffle(image_paths)
    split_index = int(len(image_paths) * train_eval_ratio)
    train_dataset = CustomImageDataset(image_paths[:split_index], transforms=train_transforms)
    eval_dataset = CustomImageDataset(image_paths[split_index:], transforms=eval_transforms)

    config = HieraConfig.from_json_file("hiera_config.json")
    model = HieraForPreTraining(config)

    # Training arguments
    training_args = TrainingArguments(
        output_dir='./models/',
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
        # NOTE(review): the gradient norm is reported to end up as NaN on long
        # runs; fp16 has a narrow dynamic range, so bf16 (where the hardware
        # supports it) may be worth trying - verify.
        fp16=True,  # Use mixed precision training
        gradient_accumulation_steps=64,
        max_grad_norm=.02,  # For gradient clipping
        adam_beta1=0.9,
        adam_beta2=0.95,
        weight_decay=.05,
        lr_scheduler_type="cosine",
        # NOTE(review): this value is counted in optimizer steps, not epochs -
        # confirm it matches the warmup length intended from the paper.
        warmup_steps=120,
        save_strategy="epoch",
        logging_steps=10,
        dataloader_num_workers=8,
        ddp_backend="nccl",
        ddp_find_unused_parameters=False,
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collate_fn,
    )

    # Start training
    trainer.train()


if __name__ == '__main__':
    main()