I’m trying to use the VQModel that ships with Hugging Face’s diffusers library. Is this implementation correct? I see artifacts in the reconstructions: the output images look as if they were never normalized. Also, is the loss computed correctly, given the change from the legacy quantizer implementation?
The data requirements also seem huge for a 256×256 image. Is there something obviously wrong here?
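For reference, this is the kind of [-1, 1] input normalization I was unsure about adding, together with the inverse scaling I would then need before visualization. It is only a sketch of what I had in mind and is not part of the script below:

import torchvision.transforms as transforms

# hypothetical: scale the [0, 1] ToTensor output to [-1, 1] before feeding the model
maybe_normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

# hypothetical: map a [-1, 1] reconstruction back to [0, 1] before ToPILImage
def maybe_denormalize(x):
    return (x * 0.5 + 0.5).clamp(0, 1)

Here is the full training script: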
import torch
from torch import nn
import torchvision
from torchvision.io import read_image
import torchvision.transforms as transforms
from dataclasses import dataclass
from datasets import load_dataset
from diffusers import VQModel
import torch.nn.functional as F
from diffusers.utils import make_image_grid
import math
import os
from accelerate import Accelerator
from huggingface_hub import HfFolder, Repository, whoami
from tqdm.auto import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from accelerate import notebook_launcher
import glob
torch.manual_seed(42)
assert torch.cuda.is_available()
@dataclass
class TrainingConfig:
    image_size = 256  # the generated image resolution
    train_batch_size = 8
    eval_batch_size = 2  # how many images to sample during evaluation
    num_epochs = 10000
    gradient_accumulation_steps = 1
    learning_rate = 1e-6
    save_image_epochs = 1
    save_model_epochs = 5
    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = "vae_results"  # the model name locally and on the HF Hub
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed = 42  # random seed for training/validation/test splits
config = TrainingConfig()
visual_preprocess_for_visualization = transforms.Compose(
    [
        transforms.ToPILImage("RGB"),
        transforms.Resize((config.image_size, config.image_size)),
    ]
)
preprocess = transforms.Compose(
    [
        transforms.Resize((config.image_size, config.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # scales to [0, 1]; no Normalize is applied
    ]
)
def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}
dataset, test_dataset = load_dataset("imagefolder", data_dir='/home/aditya/Datasets', split=['train', 'test'])
dataset.set_transform(transform)
test_dataset.set_transform(transform)
image_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=6, pin_memory=True)
image_test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.eval_batch_size, shuffle=True, num_workers=4, pin_memory=True)
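# VQModel: 3 encoder/decoder blocks (which, as I understand it, gives two 2x downsamples,
# so a 256x256 input should map to a 64x64 latent grid with latent_channels=1), quantized
# against a codebook of 8192 entries.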
model = VQModel(
    in_channels=3,
    out_channels=3,
    latent_channels=1,
    down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
    up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
    block_out_channels=[64, 64, 64],
    layers_per_block=3,
    num_vq_embeddings=8192,
    norm_type="group",
).cuda()
# model.enable_gradient_checkpointing()
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
@torch.no_grad()  # no gradients needed for evaluation; also lets ToPILImage convert the outputs
def evaluate(config, epoch, pipeline, dataloader):
    for step, batch in enumerate(dataloader):
        curr_images = batch["images"]
        pred_image = pipeline.module.forward(curr_images.cuda())
        images = []
        # Interleave each input image with its reconstruction
        for index in range(len(curr_images)):
            images.append(visual_preprocess_for_visualization(curr_images[index].cpu()))
            images.append(visual_preprocess_for_visualization(pred_image[0][index].cpu()))
        # Make a grid out of the images
        image_grid = make_image_grid(images, rows=len(images), cols=1)
        # Save the images
        test_dir = os.path.join(config.output_dir, "samples")
        os.makedirs(test_dir, exist_ok=True)
        image_grid.save(f"{test_dir}/{epoch:04d}.png")
def train_loop(config, model, optimizer, train_dataloader, test_dataloader):
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    if accelerator.is_main_process:
        os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_example")

    # Prepare everything.
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )
    # use the non-legacy commitment loss weighting in the quantizer
    model.module.quantize.legacy = False

    global_step = 0

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            clean_images = batch["images"]
            with accelerator.accumulate(model):
                # Encode, quantize, then decode to get the reconstruction
                h = model.module.encode(clean_images).latents
                z_q, codebook_commitment_loss, (perplexity, min_encodings, min_encoding_indices) = model.module.quantize.forward(h)
                pred_image = model.module.decode(z_q).sample
                # Reconstruction loss plus the quantizer's codebook/commitment loss
                loss = F.mse_loss(pred_image, clean_images) + codebook_commitment_loss
                loss = loss / config.gradient_accumulation_steps
                accelerator.backward(loss)
                # accelerator.clip_grad_norm_(model.module.parameters(), 1.0)
                if step % config.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # After each epoch you optionally sample some demo images with evaluate() and save the model
        if accelerator.is_main_process:
            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, model, test_dataloader)
            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                model.module.save_pretrained(config.output_dir)
args = (config, model, optimizer, image_dataloader, image_test_dataloader)
notebook_launcher(train_loop, args, num_nodes=2, num_processes=1)
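In case it is useful, this is the kind of quick range check I had in mind for comparing input and reconstruction values (a standalone debugging sketch, not part of the training loop; it indexes the forward output with [0] the same way evaluate() does):

batch = next(iter(image_test_dataloader))["images"].cuda()
with torch.no_grad():
    recon = model(batch)[0]
print("input range:", batch.min().item(), batch.max().item())
print("recon range:", recon.min().item(), recon.max().item())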