I’m trying to fine-tunning the VAE of SD 1.4
I’m in a multi gpu environment, and I’m using accelerate
library for handling that.
This is my code summarized:
import os
import torch.nn.functional as F
import yaml
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from diffusers import AutoencoderKL
from torch.optim import Adam
from accelerate import Accelerator
from torch.utils.tensorboard import SummaryWriter
# Load configuration
with open('config.yaml', 'r') as file:
config = yaml.safe_load(file)
def save_checkpoint(model, optimizer, epoch, step, filename="checkpoint.pth.tar"):
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch,
'step': step
}
torch.save(checkpoint, filename)
class ImageDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.images = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.png')]
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
return image
# Setup dataset and dataloader based on config
transform = Compose([
Resize((config['dataset']['image_size'], config['dataset']['image_size'])),
ToTensor(),
Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
dataset = ImageDataset(root_dir=config['dataset']['root_dir'], transform=transform)
dataloader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=True, num_workers=config['training']['num_workers'])
# Initialize model, accelerator, optimizer, and TensorBoard writer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = config['model']['path']
vae = AutoencoderKL.from_pretrained(model_path).to(device)
optimizer = Adam(vae.parameters(), lr=config['training']['learning_rate'])
accelerator = Accelerator()
vae, dataloader = accelerator.prepare(vae, dataloader)
writer = SummaryWriter()
# Training loop
for epoch in range(config['training']['num_epochs']):
vae.train()
total_loss = 0
for step, batch in enumerate(dataloader):
with accelerator.accumulate(vae):
# Assuming the first element of the batch is the image
target = batch[0].to(next(vae.parameters()).dtype)
# Access the original model for custom methods
model = vae.module if hasattr(vae, "module") else vae
posterior = model.encode(target).latent_dist
z = posterior.mode()
pred = model.decode(z).sample
kl_loss = posterior.kl().mean()
mse_loss = F.mse_loss(pred, target, reduction="mean")
loss = mse_loss + config['training']["kl_scale"] * kl_loss
optimizer.zero_grad()
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad() # Clear gradients after updating weights
# Checkpointing every 10 steps
if step % 10 == 0:
checkpoint_path = f"checkpoint_epoch_{epoch}_step_{step}.pth"
accelerator.save({
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": loss,
}, checkpoint_path)
print(f"Checkpoint saved to {checkpoint_path}")
writer.close()
print("Training complete.")
When running the code, I got the following error:
RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [128] and input of shape [128, 1024, 1024]:
My input folder contains a set of png images with different sizes, and resized to 1024x1024 in the configuration file.
I do not know why this is happening and if someone knows, or if there is a easier way to fine-tunning the VAE weights using my images.
Thanks.
Edit:
My config.yaml
file
model:
path: 'vae1dot4' # Path to your pre-trained model directory
dataset:
root_dir: 'segmented' # Directory containing your PNG images
image_size: 1024 # Target size for image resizing
training:
batch_size: 8 # Batch size for training
num_epochs: 10 # Number of epochs to train
learning_rate: 0.0005 # Learning rate for the optimizer
num_workers: 4 # Number of worker processes for data loading
kl_scale: 1
gradient_accumulation_steps: 1
logging:
tensorboard_dir: 'runs' # Directory for TensorBoard logs