I'm stuck on my DDPM model.
I derived all the formulas step by step and implemented a simple model in PyTorch, but the results look wrong. The model's output seems to fit the noise well, at least I think so, but when it comes to the reverse process, which should restore the x0 image, only noise is left.
Here is what I do in my simple DDPM model:
- The UNet's goal is to predict the noise given xt and t, and the loss function is MSE;
- For the test step, I first compute x0 from xt and the predicted noise (noisy_hat), then iteratively compute the mean and variance of x_{t-1}.
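For reference, the two bullets correspond to the standard DDPM relations below (ε-prediction parameterization). They are written with α_t = 1 − β_t and ᾱ_t = α_0 ⋯ α_t, using the same 0-based t as `range(TEST_STEP)` in the code, and with the convention ᾱ_{-1} = 1, so that at t = 0 the posterior mean reduces to x̂_0:

```latex
\begin{aligned}
x_t &= \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \\
\mathcal{L} &= \mathbb{E}\,\big\lVert \epsilon - \hat\epsilon_\theta(x_t, t) \big\rVert^2 \\
\hat x_0 &= \sqrt{\tfrac{1}{\bar\alpha_t}}\,x_t - \sqrt{\tfrac{1-\bar\alpha_t}{\bar\alpha_t}}\,\hat\epsilon_\theta(x_t, t) \\
\tilde\mu_t &= \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t + \frac{\sqrt{\bar\alpha_{t-1}}\,(1-\alpha_t)}{1-\bar\alpha_t}\,\hat x_0 \\
\tilde\beta_t &= \frac{(1-\alpha_t)(1-\bar\alpha_{t-1})}{1-\bar\alpha_t} \\
x_{t-1} &= \tilde\mu_t + \sqrt{\tilde\beta_t}\,z, \qquad z \sim \mathcal{N}(0, I) \text{ for } t > 0,\quad z = 0 \text{ for } t = 0
\end{aligned}
```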
Main code of my training implementation:

    import torch
    from torch.utils.tensorboard import SummaryWriter
    from tqdm import tqdm

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    writer = SummaryWriter(log_dir=str(LOG_PATH))
    best_loss = 1e10
    patience = 0
    progress_bar = tqdm(total=EPOCHS * EPOCH_STEPS)
    for epoch_step in range(EPOCHS * EPOCH_STEPS):
        optimizer.zero_grad()
        # noisy image xt, timestep t, and the noise that was added
        xt, t, noisy = noisy_sample(image, ᾱ, num_sample=BATCHS)
        xt = xt.to(DEVICE)
        t = t.to(DEVICE)
        noisy = noisy.to(DEVICE)
        noisy_hat = model(xt, t)
        loss = criterion(noisy, noisy_hat["sample"])
        loss.backward()
        optimizer.step()
        writer.add_scalar('train_loss', loss.item(), epoch_step)
        if loss.item() < best_loss:
            best_loss = loss.item()
            torch.save(model.state_dict(), MODEL_PATH/'best_ddpm_model.pth')
            # patience = 0
        # else:
        #     patience += 1
        #     if patience > PATIENCE:
        #         print(f'Early stop at step {epoch_step}')
        #         break
        progress_bar.update(1)
    writer.close()
    torch.save(model.state_dict(), MODEL_PATH/'model.pth')
    print(f'Final loss: {loss.item()}')

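The post does not show how `noisy_sample` or ᾱ are built. As a point of comparison, here is a minimal sketch of one common setup, assuming the linear β schedule from the original DDPM paper (β from 1e-4 to 0.02 over T = 1000 steps) and images scaled to [-1, 1]. `make_linear_schedule` and this version of `noisy_sample` are hypothetical illustrations, not the author's code:

```python
import torch


def make_linear_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule; returns alphas and their cumulative product."""
    betas = torch.linspace(beta_start, beta_end, T)
    alphas = 1.0 - betas                       # alpha_t = 1 - beta_t
    alpha_bars = torch.cumprod(alphas, dim=0)  # alpha_bar_t = alpha_0 * ... * alpha_t
    return alphas, alpha_bars


def noisy_sample(image, alpha_bars, num_sample):
    """Hypothetical version of the post's helper.

    image: (C, H, W) tensor in [-1, 1]. Returns (x_t, t, eps), where eps is
    the noise the UNet is trained to predict.
    """
    T = alpha_bars.shape[0]
    x0 = image.unsqueeze(0).repeat(num_sample, 1, 1, 1)  # (N, C, H, W)
    t = torch.randint(0, T, (num_sample,))               # one timestep per sample
    eps = torch.randn_like(x0)
    ab = alpha_bars[t].view(-1, 1, 1, 1)                 # broadcast over C, H, W
    xt = ab.sqrt() * x0 + (1 - ab).sqrt() * eps
    return xt, t, eps
```

With a setup like this, TEST_STEP in the sampling loop should normally equal the T used for training, because the same ᾱ array is indexed with the same t in both places.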
Main code of my test implementation:

    import torch
    from tqdm import tqdm

    model.eval()
    with torch.no_grad():
        # step-by-step reverse sampling
        for i in tqdm(reversed(range(TEST_STEP)), total=TEST_STEP):
            t = torch.tensor(i, dtype=torch.long).view(-1, 1).to(DEVICE)
            # predict the noise contained in xt
            noisy_hat = model(xt, t)['sample']
            # use the noise predicted by the UNet to compute x0_hat
            x0_hat = (1/ᾱ[t]).sqrt() * xt - ((1-ᾱ[t])/ᾱ[t]).sqrt()*noisy_hat
            # use x0_hat and xt to compute the mean and variance of x_{t-1}
            variance_t_1 = (1-𝛂[t]) * (1-ᾱ[t-1]) / (1-ᾱ[t]).view(1,1,1,1)
            mean_t_1 = 𝛂[t].sqrt() * (1-ᾱ[t-1])/(1-ᾱ[t]) * xt + ᾱ[t-1].sqrt() * (1-𝛂[t])/(1-ᾱ[t]) * x0_hat
            # sample x_{t-1}; no noise is added at the last step (i == 0)
            noise = torch.normal(0, 1, xt.shape).to(DEVICE)
            noisy_mask = 1 if i != 0 else 0
            noisy_mask = torch.tensor(noisy_mask).to(DEVICE)
            xt = mean_t_1 + noisy_mask * variance_t_1.sqrt() * noise

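One detail in the loop above may be worth checking. At i == 0, `ᾱ[t-1]` evaluates `ᾱ[-1]`, and PyTorch's negative indexing returns the last entry ᾱ_{T-1} ≈ 0 instead of the value 1 that the posterior formula expects before the first step. If ᾱ comes from a linear schedule starting at β = 1e-4, then 1 − ᾱ_0 ≈ 1e-4 and the coefficient on xt in `mean_t_1` is roughly 10^4 at that final step. The sketch below handles the boundary by setting ᾱ_{-1} = 1. It assumes a hypothetical noise-prediction callable `model_eps(xt, t)` taking an integer t, and images in [-1, 1]; `ddpm_reverse_step` and `sample` are illustrative names, and clamping x̂_0 is an optional extra that many implementations use:

```python
import torch


def ddpm_reverse_step(xt, eps_hat, t, alphas, alpha_bars, clip_x0=True):
    """One ancestral DDPM step x_t -> x_{t-1}; t is a 0-based Python int.

    Uses alpha_bar_{-1} := 1, so at t == 0 the posterior mean equals x0_hat
    and no noise is added.
    """
    a_t = alphas[t]
    ab_t = alpha_bars[t]
    # avoid alpha_bars[-1], which would wrap around to the LAST entry
    ab_prev = alpha_bars[t - 1] if t > 0 else alpha_bars.new_tensor(1.0)

    # invert x_t = sqrt(ab_t) * x0 + sqrt(1 - ab_t) * eps to estimate x0
    x0_hat = (xt - (1 - ab_t).sqrt() * eps_hat) / ab_t.sqrt()
    if clip_x0:
        x0_hat = x0_hat.clamp(-1.0, 1.0)  # assumes data scaled to [-1, 1]

    # mean and variance of the posterior q(x_{t-1} | x_t, x0_hat)
    coef_xt = a_t.sqrt() * (1 - ab_prev) / (1 - ab_t)
    coef_x0 = ab_prev.sqrt() * (1 - a_t) / (1 - ab_t)
    mean = coef_xt * xt + coef_x0 * x0_hat
    var = (1 - a_t) * (1 - ab_prev) / (1 - ab_t)

    if t == 0:
        return mean  # final step: return the mean without adding noise
    return mean + var.sqrt() * torch.randn_like(xt)


@torch.no_grad()
def sample(model_eps, shape, alphas, alpha_bars):
    """Run the full reverse chain starting from pure Gaussian noise."""
    xt = torch.randn(shape, device=alpha_bars.device)
    for t in reversed(range(alpha_bars.shape[0])):
        eps_hat = model_eps(xt, t)  # hypothetical: predicted noise for (xt, t)
        xt = ddpm_reverse_step(xt, eps_hat, t, alphas, alpha_bars)
    return xt
```

Running `sample` with an oracle that returns the true noise for a known x_0 should give back that x_0, which is a cheap way to test the sampling code independently of the UNet.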
Since I didn't dig into the UNet itself, maybe I'm using the wrong parameters? Or maybe I'm using the wrong formula, though I don't think so.
Here are the training and testing notebooks, as well as the already-trained model.