Help debugging diffusion model

I’m working with a custom diffusion model that isn’t training well: the loss drops quickly at first, then plateaus indefinitely.

The goal of the model is to map between paired embedding spaces (x, y): given an embedding x, predict its paired embedding y.

Currently I have an MLP that does this prediction directly. I want to see if I can get better performance using a diffusion model, similar to the diffusion prior in DALL·E 2.

I have a version of this that sort of works but has very poor performance compared to the straight MLP version. I don’t work much with diffusion models, so my current assumption is I’m doing something wrong. Does anything look clearly out of place about the code below?
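For context, the baseline is just a plain regression MLP, roughly along these lines (the depth and width here are illustrative, not the exact architecture):

class BaselineMLP(nn.Module):
    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_in),
        )

    def forward(self, x):
        # direct x -> y regression, no diffusion
        return self.layers(x)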

model

The model takes [x, noisy_y, timestep]. Timestep integers are converted to a learned embedding. The model concatenates [x, noisy_y, timestep_embedding] and sends the result through a stack of linear layers to produce a noise prediction.

Does the predict function for diffusion inference look right?

class DiffusionMLP(nn.Module):
    def __init__(self, 
                 n_timesteps, # number of diffusion timesteps
                 d_embedding, # dimension of timestep embedding
                 d_in,        # input embedding dimension
                 d_hidden     # hidden layer dimension
                ):
        super().__init__()
        
        self.noise_scheduler = DDPMScheduler(num_train_timesteps=n_timesteps,
                                             beta_schedule="sigmoid",
                                             beta_start=1e-7,
                                             beta_end=2e-3,
                                             clip_sample=False)
        
        self.time_embedding = nn.Embedding(n_timesteps, d_embedding)
        
        self.layers = nn.Sequential(
                                    nn.Linear(d_in*2+d_embedding, d_hidden),
                                    nn.ReLU(),
                                    nn.Linear(d_hidden, d_hidden),
                                    nn.ReLU(),
                                    nn.Linear(d_hidden, d_hidden),
                                    nn.ReLU(),
                                    nn.Linear(d_hidden, d_hidden),
                                    nn.ReLU(),
                                    nn.Linear(d_hidden, d_hidden),
                                    nn.ReLU(),
                                    nn.Linear(d_hidden, d_in),
                                    )
        
    def forward(self, x, noisy_y, timesteps):
        # convert timestep integer to embedding
        timesteps = self.time_embedding(timesteps)
        
        # concat inputs
        x = torch.cat([x, noisy_y, timesteps], -1)
        
        # predict noise
        noise_prediction = self.layers(x)
        return noise_prediction
    
    def predict(self, x):
        
        # create starting noise
        noisy_y = torch.randn(x.shape, device=x.device)
        
        # iterate over diffusion steps
        for t in self.noise_scheduler.timesteps:
            # batch of the current timestep as a LongTensor
            timesteps = torch.full((x.shape[0],), int(t), device=x.device, dtype=torch.long)
            
            # predict noise
            with torch.no_grad():
                noise_prediction = self.forward(x, noisy_y, timesteps)
                
            # denoise to previous sample
            noisy_y = self.noise_scheduler.step(noise_prediction, t, noisy_y).prev_sample
            
        return noisy_y

training code

# AdamW optimizer
opt = optim.AdamW(model.parameters(), lr=lr)

# cosine LR schedule with warmup
scheduler = get_cosine_schedule_with_warmup(
    optimizer=opt,
    num_warmup_steps=1000,
    num_training_steps=(len(train_dataloader) * epochs),
)

# train loop

for epoch in range(epochs):
    for batch in tqdm(train_dataloader):
        x,y = batch # both x,y are of shape (n, d)
        
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        
        bs = x.shape[0]

        # sample random noise
        noise = torch.randn(y.shape, device=y.device)

        # sample timesteps
        timesteps = torch.randint(0, model.noise_scheduler.config.num_train_timesteps, (bs,),
                                  device=y.device).long()

        # add noise to y embedding
        noisy_y = model.noise_scheduler.add_noise(y, noise, timesteps)
                
        # predict noise
        noise_pred = model(x, noisy_y, timesteps)
                
        # mse loss between noise prediction and input noise
        loss = F.mse_loss(noise_pred, noise)

        # optimizer step
        opt.zero_grad()
        loss.backward()
        opt.step()
        scheduler.step()

As a validation metric, I’m looking at cosine similarity between the ground truth y embeddings and the outputs generated by model.predict(x). After a lot of training, the diffusion outputs have a cosine similarity of ~0.77 to the ground truth, compared to ~0.92 for the baseline model. Training loss (MSE on noise prediction) is stuck at around 0.33.
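For reference, the metric is computed roughly like this (a sketch; x_val / y_val are placeholder names for a held-out batch):

with torch.no_grad():
    # generate predictions via the full reverse diffusion loop
    y_pred = model.predict(x_val)
    # mean cosine similarity across the batch
    cos_sim = F.cosine_similarity(y_pred, y_val, dim=-1).mean()
    print(f"val cosine similarity: {cos_sim.item():.3f}")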

I’ve noticed that as the noise prediction MSE loss goes down, the cosine similarity of the reconstructions also goes down (i.e. they get worse). This makes me think there might be a bug in the predict code, but I haven’t found it yet.

I’ve also noticed that noise prediction loss correlates strongly with timestep. Below is a plot of noise prediction MSE loss as a function of timestep. I’m not sure if this is typical for diffusion models or if it points to an issue.

[plot: noise prediction MSE loss vs. timestep]
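The per-timestep curve was produced by evaluating the noise prediction MSE at each fixed timestep, roughly like this (a sketch; x_val / y_val are the same held-out batch as above):

per_t_loss = []
with torch.no_grad():
    for t in range(model.noise_scheduler.config.num_train_timesteps):
        # whole batch noised to the same timestep t
        timesteps = torch.full((y_val.shape[0],), t, device=y_val.device, dtype=torch.long)
        noise = torch.randn_like(y_val)
        noisy_y = model.noise_scheduler.add_noise(y_val, noise, timesteps)
        noise_pred = model(x_val, noisy_y, timesteps)
        per_t_loss.append(F.mse_loss(noise_pred, noise).item())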

Reconstructions from model.predict also have much larger norms than the target vectors. I know the scheduler has a clip_sample parameter, but that feels like a band-aid over a more fundamental issue.
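For example, comparing mean L2 norms (same placeholder held-out tensors as above):

with torch.no_grad():
    y_pred = model.predict(x_val)
print("mean pred norm:  ", y_pred.norm(dim=-1).mean().item())
print("mean target norm:", y_val.norm(dim=-1).mean().item())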

I’d appreciate anyone pointing out mistakes in how I’m setting up the problem or applying the diffusion scheduler.