I’m working with a custom diffusion model that’s having trouble training: the loss drops quickly at first, then plateaus indefinitely.

The goal of the model is to map between paired embedding spaces (x, y): given an embedding x, predict its paired embedding y.

Currently I have an MLP that does this prediction directly. I want to see if I can get better performance from a diffusion model, similar to the diffusion prior in DALL-E 2.

I have a version of this that sort of works, but its performance is much worse than the straight MLP version. I don’t work much with diffusion models, so my current assumption is that I’m doing something wrong. Does anything look clearly out of place in the code below?
model
The model is set up to take in [x, noised_y, timestep]. The timestep integers are converted to an embedding, then the model concatenates [x, noised_y, timestep_embedding] and sends the result through a series of linear layers to produce a noise prediction.

Does the predict function for diffusion inference look right?
import torch
from torch import nn
from diffusers import DDPMScheduler


class DiffusionMLP(nn.Module):
    def __init__(self,
                 n_timesteps,  # number of diffusion timesteps
                 d_embedding,  # dimension of the timestep embedding
                 d_in,         # input embedding dimension
                 d_hidden):    # hidden layer dimension
        super().__init__()
        # DDPM noise scheduler, shared by training (add_noise) and inference (step)
        self.noise_scheduler = DDPMScheduler(num_train_timesteps=n_timesteps,
                                             beta_schedule="sigmoid",
                                             beta_start=1e-7,
                                             beta_end=2e-3,
                                             clip_sample=False)
        # learned embedding for the integer timestep
        self.time_embedding = nn.Embedding(n_timesteps, d_embedding)
        self.layers = nn.Sequential(
            nn.Linear(d_in * 2 + d_embedding, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_in),
        )

    def forward(self, x, noisy_y, timesteps):
        # convert timestep integers to embeddings
        timesteps = self.time_embedding(timesteps)
        # concatenate conditioning embedding, noised target, and timestep embedding
        x = torch.cat([x, noisy_y, timesteps], -1)
        # predict the noise component
        noise_prediction = self.layers(x)
        return noise_prediction

    def predict(self, x):
        # create starting noise
        noisy_y = torch.randn(x.shape, device=x.device)
        # iterate over the reverse diffusion steps
        for t in self.noise_scheduler.timesteps:
            # broadcast the current timestep across the batch as a LongTensor
            timesteps = torch.zeros(x.shape[0], device=x.device).long() + t
            # predict noise
            with torch.no_grad():
                noise_prediction = self.forward(x, noisy_y, timesteps)
            # denoise to the previous sample
            noisy_y = self.noise_scheduler.step(noise_prediction, t, noisy_y).prev_sample
        return noisy_y
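For context, this is roughly how I instantiate the model and call predict; the hyperparameter values here are just placeholders for illustration, not my exact settings:

# hypothetical hyperparameters, for illustration only
model = DiffusionMLP(n_timesteps=1000, d_embedding=128, d_in=768, d_hidden=2048)

# a batch of conditioning embeddings x with shape (n, d_in)
x = torch.randn(16, 768)

# predict() runs the full reverse diffusion loop and returns a denoised estimate of y
y_hat = model.predict(x)
print(y_hat.shape)  # torch.Size([16, 768])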
training code
import torch
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm
from diffusers.optimization import get_cosine_schedule_with_warmup

# AdamW optimizer
opt = optim.AdamW(model.parameters(), lr=lr)

# cosine LR schedule with warmup
scheduler = get_cosine_schedule_with_warmup(
    optimizer=opt,
    num_warmup_steps=1000,
    num_training_steps=(len(train_dataloader) * epochs),
)

# train loop
for epoch in range(epochs):
    for batch in tqdm(train_dataloader):
        x, y = batch  # both x and y are of shape (n, d)
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        bs = x.shape[0]
        # sample random noise
        noise = torch.randn(y.shape, device=y.device)
        # sample a random timestep for each element of the batch
        timesteps = torch.randint(0, model.noise_scheduler.config.num_train_timesteps, (bs,),
                                  device=y.device).long()
        # add noise to the y embedding according to the sampled timesteps
        noisy_y = model.noise_scheduler.add_noise(y, noise, timesteps)
        # predict noise
        noise_pred = model(x, noisy_y, timesteps)
        # MSE loss between the noise prediction and the sampled noise
        loss = F.mse_loss(noise_pred, noise)
        # optimizer step
        opt.zero_grad()
        loss.backward()
        opt.step()
        scheduler.step()
As a validation metric, I’m looking at the cosine similarity between the ground truth y embeddings and the outputs generated by model.predict(x). After a lot of training, the diffusion outputs have a cosine similarity of ~0.77 to the ground truth, compared to ~0.92 for the baseline model. Training loss (MSE on the noise prediction) is stuck at around 0.33.
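For reference, this is roughly how I compute that metric; it’s a simplified sketch of my eval loop, and val_dataloader is a placeholder name for my held-out split:

# simplified sketch of the validation metric (val_dataloader is a placeholder name)
sims = []
with torch.no_grad():
    for x, y in val_dataloader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        # generate y embeddings via the full reverse diffusion loop
        y_hat = model.predict(x)
        # per-sample cosine similarity between prediction and ground truth
        sims.append(F.cosine_similarity(y_hat, y, dim=-1))
mean_sim = torch.cat(sims).mean().item()  # ~0.77 for the diffusion model vs. ~0.92 for the MLP baseline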
I’ve noticed that as the noise prediction MSE loss goes down, the cosine similarity of the reconstructions gets worse. This makes me think there might be a bug in the predict code, but I haven’t found it yet.
I’ve also noticed that the noise prediction loss correlates strongly with the timestep; plotting noise prediction MSE loss as a function of timestep makes this clear. I’m not sure if this is typical for diffusion models or if it points to an issue.
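For context, that per-timestep breakdown comes from logging the per-sample loss against the sampled timesteps inside the training loop, roughly like this (a simplified sketch; timestep_log and loss_log are just accumulator lists I use for plotting):

# inside the training loop, after computing noise_pred
with torch.no_grad():
    # per-sample MSE over the embedding dimension, shape (bs,)
    per_sample_loss = F.mse_loss(noise_pred, noise, reduction="none").mean(dim=-1)
    # accumulate (timestep, loss) pairs to later plot loss vs. timestep
    timestep_log.append(timesteps.cpu())
    loss_log.append(per_sample_loss.cpu())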
Reconstructions from model.predict also have much higher magnitude than the target vectors. I know the scheduler has a clip parameter, but that feels like a band-aid over a more fundamental issue.
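For what it’s worth, this is the kind of quick check where I see the magnitude gap (a rough sketch; x and y here are a validation batch):

# compare average L2 norms of predicted vs. target embeddings on a validation batch
with torch.no_grad():
    y_hat = model.predict(x)
print(y_hat.norm(dim=-1).mean().item(), y.norm(dim=-1).mean().item())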
I’d appreciate anyone pointing out mistakes I’m making in setting up the problem or applying the diffusion scheduler.