VQModel usage issues

I’m trying to use the VQModel that ships with Hugging Face’s diffusers library. Is this implementation correct? I see artifacts in the reconstructions, where the images don’t look like they have been normalized. Also, is the loss computed correctly, given that I switch the quantizer away from the legacy implementation (quantize.legacy = False)?

The data requirements also seem huge for a 256×256 image. Is there something obviously wrong here?
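
For reference, a quick way to show what I mean by the reconstructions not looking normalized would be to inspect the raw decoder output range (a sketch only, using the model and test dataloader defined in the full script below; not part of the training loop):

with torch.no_grad():
    batch = next(iter(image_test_dataloader))
    recon = model(batch["images"].cuda()).sample  # decoder output, same shape as the input batch
    print("input range:", batch["images"].min().item(), batch["images"].max().item())
    print("recon range:", recon.min().item(), recon.max().item())

Full script: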

import torch
from torch import nn
import torchvision
from torchvision.io import read_image
import torchvision.transforms as transforms
from dataclasses import dataclass
from datasets import load_dataset
from diffusers import VQModel
import torch.nn.functional as F
from diffusers.utils import make_image_grid
import math
import os
from accelerate import Accelerator
from huggingface_hub import HfFolder, Repository, whoami
from tqdm.auto import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from accelerate import notebook_launcher
import glob

torch.manual_seed(42)
assert torch.cuda.is_available()
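
# Training hyperparameters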
@dataclass
class TrainingConfig:
    image_size = 256  # the generated image resolution
    train_batch_size = 8
    eval_batch_size = 2  # how many images to sample during evaluation
    num_epochs = 10000
    gradient_accumulation_steps = 1
    learning_rate = 1e-6
    save_image_epochs = 1
    save_model_epochs = 5
    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = "vae_results"  # the model name locally and on the HF Hub
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed = 42  # random seed for training/validation/test splits
config = TrainingConfig()

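# Convert a tensor back to a resized RGB PIL image for the saved comparison grids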
visual_preprocess_for_visualization = transforms.Compose(
    [
        transforms.ToPILImage("RGB"),
        transforms.Resize((config.image_size, config.image_size)),
    ]
)

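# Dataset preprocessing: resize, random horizontal flip, and ToTensor (scales pixels to [0, 1]; no mean/std normalization)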
preprocess = transforms.Compose(
    [
        transforms.Resize((config.image_size, config.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

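# Applied lazily to every example via datasets' set_transform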
def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}

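# Load the train/test splits from a local image folder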
dataset, test_dataset = load_dataset("imagefolder", data_dir='/home/aditya/Datasets', split=['train', 'test'])
dataset.set_transform(transform)
test_dataset.set_transform(transform)

image_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=6, pin_memory=True)
image_test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.eval_batch_size, shuffle=True, num_workers=4, pin_memory=True)

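# VQ-VAE: 3-channel input/output, 1-channel latents, three encoder/decoder stages, 8192-entry codebook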
model = VQModel(
    in_channels=3,
    out_channels=3,
    latent_channels=1,
    down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
    up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
    block_out_channels=[64, 64, 64],
    layers_per_block=3,
    num_vq_embeddings=8192,
    norm_type="group",
).cuda()

# model.enable_gradient_checkpointing()

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

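# Reconstruct test batches and save originals and reconstructions stacked into a single grid image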
def evaluate(config, epoch, pipeline, dataloader):
    for step, batch in enumerate(dataloader):
        curr_images  = batch["images"]
        pred_image = pipeline.module(curr_images.cuda())
        images = []
        for index in range(len(curr_images)):        
            images.append(visual_preprocess_for_visualization(curr_images[index].cpu()))
            images.append(visual_preprocess_for_visualization(pred_image[0][index]))
        # Make a grid out of the images
        image_grid = make_image_grid(images, rows=len(images), cols=1)

        # Save the images
        test_dir = os.path.join(config.output_dir, "samples")
        os.makedirs(test_dir, exist_ok=True)
        image_grid.save(f"{test_dir}/{epoch:04d}.png")

def train_loop(config, model, optimizer, train_dataloader, test_dataloader):
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    if accelerator.is_main_process:
        os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_example")

    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

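    # Use the non-legacy codebook/commitment loss weighting in the vector quantizer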
    model.module.quantize.legacy = False
    global_step = 0

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            clean_images = batch["images"]

            with accelerator.accumulate(model):
                # Encode, quantize the latents, and decode the reconstruction
                h = model.module.encode(clean_images).latents
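                # The quantizer returns the quantized latents, the codebook/commitment loss, and auxiliary stats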
                z_q, codebook_commitment_loss, (perplexity, min_encodings, min_encoding_indices) = model.module.quantize(h)
                pred_image = model.module.decode(z_q).sample
                loss = F.mse_loss(clean_images, pred_image) + codebook_commitment_loss
                loss = loss / config.gradient_accumulation_steps
                accelerator.backward(loss)
                # accelerator.clip_grad_norm_(model.module.parameters(), 1.0)
                if step % config.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1
        # After each epoch you optionally sample some demo images with evaluate() and save the model
        if accelerator.is_main_process:
            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, model, test_dataloader)

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                model.module.save_pretrained(config.output_dir)

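# Launch the training loop with Accelerate's notebook launcher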
args = (config, model, optimizer, image_dataloader, image_test_dataloader)
notebook_launcher(train_loop, args, num_nodes=2, num_processes=1)