Need help with a DDPM model

I am stuck on a DDPM model.
I derived all the formulas step by step and implemented a simple model using PyTorch, but the results show that something is wrong. The output of my model fits the noise well (at least I think so), but in the reverse process, which should restore the x0 image, only noise is left.

Here is what I do in my simple DDPM model:

  1. The goal of the UNet is to predict the noise given xt and t, and the loss function is MSE;
  2. For the test step, I first compute x0_hat from xt and noise_hat, then iteratively compute the mean and variance of x_{t-1}.

Main code of my training implementation:

    # Train the noise-prediction network: every step takes (xt, t, noisy) from
    # `noisy_sample` and regresses the model output onto `noisy` with an MSE
    # loss. The lowest single-step loss seen so far is checkpointed as
    # 'best_ddpm_model.pth'; the final weights are saved as 'model.pth'.
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    writer = SummaryWriter(log_dir=str(LOG_PATH))
    best_loss = 1e10
    patience = 0
    total_steps = EPOCHS * EPOCH_STEPS
    # Defined up front so the final report still works when total_steps == 0.
    loss_value = float('nan')

    progress_bar = tqdm(total=total_steps)
    try:
        for epoch_step in range(total_steps):
            optimizer.zero_grad()
            xt, t, noisy = noisy_sample(image, ᾱ, num_sample=BATCHS)
            xt = xt.to(DEVICE)
            t = t.to(DEVICE)
            noisy = noisy.to(DEVICE)
            noisy_hat = model(xt, t)
            # MSELoss takes (input, target); the value is the same either way,
            # but the prediction goes first by convention.
            loss = criterion(noisy_hat["sample"], noisy)
            loss.backward()
            optimizer.step()
            # Read the scalar loss once per step instead of calling .item()
            # separately for logging, comparison and bookkeeping.
            loss_value = loss.item()
            writer.add_scalar('train_loss', loss_value, epoch_step)
            if loss_value < best_loss:
                best_loss = loss_value
                torch.save(model.state_dict(), MODEL_PATH/'best_ddpm_model.pth')
                # patience = 0
            # Early stopping is intentionally disabled:
            # else:
            #     patience += 1
            #     if patience > PATIENCE:
            #         print(f'Early stop at step {epoch_step}')
            #         break
            progress_bar.update(1)
    finally:
        # Release the progress bar and flush the TensorBoard writer even if a
        # training step raises.
        progress_bar.close()
        writer.close()
    torch.save(model.state_dict(), MODEL_PATH/'model.pth')
    print(f'Finall loss:{loss_value}')

Main code of my test implementation:

# Reverse (denoising) process: starting from the current `xt`, walk the
# timesteps from TEST_STEP - 1 down to 0 and sample x_{t-1} from the DDPM
# posterior q(x_{t-1} | x_t, x0_hat) at every step.
# Assumes ᾱ is the cumulative product of 𝛂, indexed by the 0-based timestep
# -- TODO confirm against the schedule definition.
model.eval()
with torch.no_grad():
    # step by step reverse sample
    for i in tqdm(reversed(range(TEST_STEP)), total=TEST_STEP):
        t = torch.tensor(i, dtype=torch.long).view(-1, 1).to(DEVICE)

        # noise predicted by the UNet for the current xt
        noisy_hat = model(xt, t)['sample']

        # Schedule values for this step, reshaped so they broadcast over xt.
        alpha_t = 𝛂[t].reshape(1, 1, 1, 1)
        abar_t = ᾱ[t].reshape(1, 1, 1, 1)
        # ᾱ_{t-1} is defined as 1 for the first timestep. Indexing ᾱ[t-1] at
        # i == 0 would read ᾱ[-1] (the LAST schedule entry, because negative
        # indices wrap around) and produce wrong posterior coefficients in the
        # final denoising step.
        if i > 0:
            abar_prev = ᾱ[t - 1].reshape(1, 1, 1, 1)
        else:
            abar_prev = torch.ones_like(abar_t)

        # use noisy_hat predicted by the UNet to estimate x0_hat
        x0_hat = (1 / abar_t).sqrt() * xt - ((1 - abar_t) / abar_t).sqrt() * noisy_hat

        # use x0_hat and xt to compute the posterior mean and variance of x_{t-1}
        variance_t_1 = (1 - alpha_t) * (1 - abar_prev) / (1 - abar_t)
        mean_t_1 = (
            alpha_t.sqrt() * (1 - abar_prev) / (1 - abar_t) * xt
            + abar_prev.sqrt() * (1 - alpha_t) / (1 - abar_t) * x0_hat
        )

        # sample x_{t-1}; no noise is added at the last step (i == 0)
        noise = torch.normal(0, 1, xt.shape).to(DEVICE)
        noisy_mask = 1 if i != 0 else 0
        noisy_mask = torch.tensor(noisy_mask).to(DEVICE)
        xt = mean_t_1 + noisy_mask * variance_t_1.sqrt() * noise

Since I didn't dive into the UNet, maybe I am using the wrong parameters? Or maybe I am using the wrong formula, although I don't think so.

Here are the training and testing notebooks, as well as the already-trained model.

For anyone who may face the same problem, here is the conclusion: it was a UNet problem.

I got proper results after I increased the UNet down and up layers to 5 and also ran the prediction 5 times at each reverse time step. If I loop 10 times at each reverse time step, the result eventually becomes a white picture.

Main change to the model:

from diffusers import UNet2DModel


## using diffusers UNet2DModel, since I dont care about the details of the unet
class my_model(torch.nn.Module):
    """Noise-prediction network backed by a diffusers ``UNet2DModel``.

    The wrapped UNet is configured for single-channel input and output with
    five down blocks and five up blocks.
    """

    def __init__(self):
        super().__init__()
        # Earlier, smaller configuration kept for reference:
        # in_channels=1, out_channels=1, block_out_channels=(16, 32, 64, 128), norm_num_groups=8
        unet_config = dict(
            sample_size=16,  # 2**(len(block_out_channels) - 1)
            in_channels=1,
            out_channels=1,
            block_out_channels=(32, 64, 64, 64, 128),
            layers_per_block=2,
            # one plain block followed by four attention blocks on each side
            down_block_types=("DownBlock2D",) + ("AttnDownBlock2D",) * 4,
            up_block_types=("UpBlock2D",) + ("AttnUpBlock2D",) * 4,
        )
        self.model = UNet2DModel(**unet_config)
        # self.model = UNet2DModel()

    #  backward xt->x0
    def forward(self, xt, t):
        """Run the wrapped UNet on ``xt`` at timestep(s) ``t``.

        ``t`` has its size-1 dimensions squeezed away and is cast to float
        before being passed on. The wrapped model's output is returned
        unchanged.
        """
        timesteps = t.squeeze().float()
        return self.model(xt, timesteps)

Main change to the reverse time step loop:

# Number of times the denoising update is repeated at every reverse timestep.
LOOP = 5

# Reverse (denoising) process with LOOP repeated updates per timestep.
# Assumes ᾱ is the cumulative product of 𝛂, indexed by the 0-based timestep,
# and that `x_progressive` is a list created before this loop -- TODO confirm.
with torch.no_grad():
    # step by step reverse sample
    for i in tqdm(reversed(range(TEST_STEP)), total=TEST_STEP):

        t = torch.tensor(i, dtype=torch.long).view(-1, 1).to(DEVICE)

        # The schedule values depend only on t, so compute them once per
        # timestep rather than once per inner repetition.
        alpha_t = 𝛂[t].reshape(1, 1, 1, 1)
        abar_t = ᾱ[t].reshape(1, 1, 1, 1)
        # ᾱ_{t-1} is defined as 1 for the first timestep. Indexing ᾱ[t-1] at
        # i == 0 would read ᾱ[-1] (the LAST schedule entry, because negative
        # indices wrap around) and produce wrong posterior coefficients in the
        # final denoising step.
        if i > 0:
            abar_prev = ᾱ[t - 1].reshape(1, 1, 1, 1)
        else:
            abar_prev = torch.ones_like(abar_t)

        variance_t_1 = (1 - alpha_t) * (1 - abar_prev) / (1 - abar_t)
        xt_coef = alpha_t.sqrt() * (1 - abar_prev) / (1 - abar_t)
        x0_coef = abar_prev.sqrt() * (1 - alpha_t) / (1 - abar_t)

        # no noise is added at the last step (i == 0)
        noisy_mask = 1 if i != 0 else 0
        noisy_mask = torch.tensor(noisy_mask).to(DEVICE)

        for j in range(LOOP):
            # noise predicted by the UNet for the current xt
            noisy_hat = model(xt, t)['sample']

            # use noisy_hat predicted by the UNet to estimate x0_hat,
            # clamped to [-1, 1]
            x0_hat = (1 / abar_t).sqrt() * xt - ((1 - abar_t) / abar_t).sqrt() * noisy_hat
            x0_hat = x0_hat.clamp(-1, 1)

            # posterior mean of x_{t-1} given xt and x0_hat
            mean_t_1 = xt_coef * xt + x0_coef * x0_hat

            # sample x_{t-1}
            noise = torch.normal(0, 1, xt.shape).to(DEVICE)
            xt = mean_t_1 + noisy_mask * variance_t_1.sqrt() * noise
            # keep the x0 estimate from the first repetition of every 10th timestep
            if i % 10 == 0 and j == 0:
                x_progressive.append(x0_hat)