Loss not Decreasing: Hiera MAE Pretraining from Scratch

I have been trying to run MAE pretraining from scratch on a new Hiera model using my own dataset of 3D grayscale images, each 6x256x256 (6 slices of 256x256 pixels). However, I am still fairly new to PyTorch and have been running into issues.

I initially wrote my own training script, relying only on the code in the official Hiera GitHub repository. However, the loss decreased over the first few iterations and then stopped changing for a dozen or so epochs. Figuring the problem lay in my code, I switched to the HieraForPreTraining class from the transformers library, which I modified to support 3D images using much of the code from the Hiera repository. I ran into exactly the same issue: over a handful of iterations the loss fell from a little over 1.0 and then stayed stuck at 0.9951. I have also noticed that if I let training continue while the loss is stuck, the gradient norm keeps decreasing until it is eventually reported as NaN.

I have attached my short training script using the transformers library, since it shows the parameters and image transformations I am using; I tried to stick to the training parameters given in the Hiera paper for the video models. I have also included the Hiera configuration JSON file. The main difference from the paper is that, because of resource constraints, my batch size has to be much smaller than recommended, although I try to compensate for this with gradient accumulation.

{
	"embed_dim": 144,
	"image_size": [6, 256, 256],
	"patch_size": [3, 7, 7],
	"patch_stride": [2, 4, 4],
	"patch_padding": [1, 3, 3],
	"num_heads": [2, 4, 8, 16],
	"num_query_pool": 2,
	"query_stride": [1, 2, 2],
	"masked_unit_size": [1, 8, 8],
	"drop_path_rate": 0.2,
	"num_channels": 1,
	"decoder_hidden_size": 512,
	"decoder_depth": 8,
	"decoder_num_heads": 16,
	"mask_ratio": 0.75,
	"out_indices": [0, 1, 2, 3]
}
class ThreeDimColorJitter(torch.nn.Module):
    """Apply one random brightness/contrast jitter to every slice of a volume.

    A single brightness factor and a single contrast factor are sampled per
    call (via ``v2.ColorJitter.get_params``). They are applied to all depth
    slices, so the whole volume receives the same intensity transform.
    Brightness is always applied before contrast. Contrast is adjusted around
    each slice's own mean, as ``F.adjust_contrast`` does for a single 2-D image.

    Args:
        brightness: Forwarded to ``v2.ColorJitter`` (e.g. ``0.2`` samples a
            factor in ``[0.8, 1.2]``); ``0``/``None`` disables it.
        contrast: Forwarded to ``v2.ColorJitter``; ``0``/``None`` disables it.
    """

    def __init__(self, brightness, contrast):
        super().__init__()
        self.brightness = brightness
        self.contrast = contrast
        # Build (and validate) the jitter once rather than on every sample. It
        # is only used to convert the arguments into sampling ranges.
        self._jitter = v2.ColorJitter(brightness=brightness, contrast=contrast)

    def forward(self, image):
        """Jitter ``image`` and return a new ``(1, D, H, W)`` tensor.

        Args:
            image: Tensor indexed as ``(C, D, H, W)``. Only channel 0 is used;
                the dataset in this script always produces ``C == 1``.

        Returns:
            A new tensor of shape ``(1, D, H, W)`` with ``image``'s dtype.
        """
        _, brightness_factor, contrast_factor, _, _ = self._jitter.get_params(
            self._jitter.brightness, self._jitter.contrast,
            saturation=None, hue=None,
        )

        # Treat the D slices as a batch of single-channel 2-D images,
        # (D, 1, H, W). The functional ops act on the trailing (C, H, W) dims,
        # so this matches adjusting each slice separately.
        slices = image[0].unsqueeze(1)
        # get_params yields None for a disabled (0/None) range; skip that op
        # instead of passing None into the functional call.
        if brightness_factor is not None:
            slices = F.adjust_brightness(slices, brightness_factor)
        if contrast_factor is not None:
            slices = F.adjust_contrast(slices, contrast_factor)
        else:
            # Keep the "always returns a fresh tensor" contract even when both
            # ops are disabled.
            slices = slices.clone()

        return slices.squeeze(1).unsqueeze(0)

class CustomImageDataset(Dataset):
    """Map-style dataset that loads one multi-slice TIFF volume per item.

    Each item is read with ``tifffile.imread`` and converted to a tensor. A
    leading channel axis is added, and ``transforms`` is applied if one was
    provided.

    Args:
        image_paths: Sequence of file paths, one volume per path.
        transforms: Optional callable applied to the ``(1, D, H, W)`` tensor.
    """

    def __init__(self, image_paths, transforms=None):
        self.image_paths = image_paths  # List of file paths to images
        self.transforms = transforms

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]

        # Expected on-disk shape is (D, H, W). NOTE(review): hiera_config.json
        # declares image_size [6, 256, 256], and nothing in this pipeline
        # resizes. Confirm the TIFFs really are 6x256x256; an older comment
        # here claimed (6, 512, 512).
        image = torch.from_numpy(tifffile.imread(image_path))
        image = image.unsqueeze(0)  # Add channel axis -> (1, D, H, W)

        # Compare against None explicitly: transform objects need not be
        # truthy (e.g. an empty nn.Sequential has len 0).
        if self.transforms is not None:
            image = self.transforms(image)

        return image
        
def collate_fn(examples):
    """Stack per-sample volumes into the batch dict ``HieraForPreTraining`` expects.

    Args:
        examples: Iterable of equally-shaped tensors (one per sample).

    Returns:
        ``{"pixel_values": tensor}`` where ``tensor`` has a new leading batch
        dimension.
    """
    # torch.stack needs a list/tuple; list() materialises any iterable without
    # the identity comprehension.
    return {"pixel_values": torch.stack(list(examples))}

def get_image_paths(directories, extensions=('.tiff',)):
    """Collect image file paths from several directories (non-recursive).

    Directories are visited in the order given, and each directory's entries
    are visited in sorted filename order, so the result is deterministic.
    Sub-directories are ignored even if their names end in an accepted suffix.

    Args:
        directories: Iterable of directory paths to scan.
        extensions: Suffix string or iterable of suffixes to accept. Matching
            is case-sensitive. Defaults to ``('.tiff',)``; pass
            ``('.tiff', '.tif')`` to also include the common short form.

    Returns:
        List of paths built with ``os.path.join(directory, filename)``.
    """
    # str.endswith accepts a str or a tuple, not a list.
    suffixes = extensions if isinstance(extensions, str) else tuple(extensions)

    image_paths = []
    for data_dir in directories:
        for filename in sorted(os.listdir(data_dir)):
            path = os.path.join(data_dir, filename)
            # Skip sub-directories; tifffile.imread would fail on them later.
            if filename.endswith(suffixes) and os.path.isfile(path):
                image_paths.append(path)
    return image_paths

def main():
    """Pretrain a 3-D Hiera MAE on the TIFF volumes under ``data_dirs``.

    Steps:
        1. Build a reproducible train/eval split of the discovered files.
        2. Wrap each split in a ``CustomImageDataset``.
        3. Load ``hiera_config.json`` into ``HieraForPreTraining``.
        4. Run ``transformers.Trainer``.
    """
    # Paths and hyperparameters
    data_dirs = [
        './images_1',
        './images_2',
        './images_3',
        './images_4',
    ]
    train_eval_ratio = 0.8
    # Fixed seed for the split. Under DDP every rank executes main()
    # independently. An unseeded shuffle would give each rank a *different*
    # partition, so one rank's eval files would be another rank's train files.
    split_seed = 42
    image_paths = get_image_paths(data_dirs)

    batch_size = 8  # Adjust based on your GPU memory
    epochs = 800  # Adjust based on training results
    learning_rate = 1.6e-3

    # Training-time augmentation.
    # NOTE(review): ToDtype(scale=True) divides by the integer dtype's maximum
    # (e.g. 65535 for uint16). If the sensor uses fewer bits, the inputs only
    # occupy a small slice of [0, 1]. Consider checking the value range.
    train_transforms = v2.Compose([
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomVerticalFlip(p=0.5),
        ThreeDimColorJitter(brightness=0.2, contrast=0.2),
        v2.ToDtype(torch.float32, scale=True),  # Converts to float and scales values to be between [0, 1]
    ])
    # Evaluation must be deterministic: only the dtype conversion, with no
    # random flips or jitter, so eval losses are comparable across checkpoints.
    eval_transforms = v2.Compose([
        v2.ToDtype(torch.float32, scale=True),
    ])

    # Create Datasets (paths from get_image_paths are already in a
    # deterministic order, so a seeded shuffle gives the same split every run).
    random.Random(split_seed).shuffle(image_paths)
    split_index = int(len(image_paths) * train_eval_ratio)

    train_dataset = CustomImageDataset(image_paths[:split_index], transforms=train_transforms)
    eval_dataset = CustomImageDataset(image_paths[split_index:], transforms=eval_transforms)

    config = HieraConfig.from_json_file("hiera_config.json")
    model = HieraForPreTraining(config)

    # Training arguments.
    # Effective batch = per_device_train_batch_size * gradient_accumulation_steps
    # * number of processes (8 * 64 = 512 per process here).
    # NOTE(review): with fp16 AMP, steps whose gradients overflow are skipped
    # and the logged grad norm can appear as NaN. If the hardware supports it,
    # bf16=True avoids loss scaling entirely.
    training_args = TrainingArguments(
        output_dir='./models/',
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
        fp16=True,  # Use mixed precision training
        gradient_accumulation_steps=64,
        max_grad_norm=.02,  # For gradient clipping
        adam_beta1=0.9,
        adam_beta2=0.95,
        weight_decay=.05,
        lr_scheduler_type="cosine",
        warmup_steps=120,
        save_strategy="epoch",
        logging_steps=10,
        dataloader_num_workers=8,
        ddp_backend="nccl",
        ddp_find_unused_parameters=False,
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collate_fn,
    )

    # Start training
    trainer.train()


if __name__ == '__main__':
    main()
1 Like