I'm working with a custom diffusion model that's having trouble training: the loss drops quickly at first, then plateaus indefinitely.

The goal of the model is to map between paired embedding spaces `(x, y)`: given an embedding `x`, predict its paired embedding `y`.

Currently I have an MLP model that does this prediction directly. I want to see if I can get better performance using a diffusion model, similar to the diffusion prior in DALLE 2.
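
For reference, the baseline is just a plain feed-forward regressor from `x` to `y`, roughly like this (a minimal sketch; the actual sizes and depth may differ):

```
import torch.nn as nn

# Rough sketch of the baseline direct-prediction MLP (hypothetical sizes).
class BaselineMLP(nn.Module):
    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_in),
        )

    def forward(self, x):
        # directly predict the paired embedding y from x
        return self.layers(x)
```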

I have a version of this that sort of works, but its performance is much worse than the plain MLP version. I don't work much with diffusion models, so my current assumption is that I'm doing something wrong. Does anything look clearly out of place in the code below?

## model

The model is set up to take in `[x, noised_y, timestep]`. Timestep integers are converted to an embedding, the model concatenates `[x, noised_y, timestep_embedding]`, and the result is passed through a series of linear layers to generate a noise prediction.

Does the `predict` function for diffusion inference look right?

```
import torch
import torch.nn as nn
from diffusers import DDPMScheduler


class DiffusionMLP(nn.Module):
    def __init__(self,
                 n_timesteps,  # number of diffusion timesteps
                 d_embedding,  # dimension of timestep embedding
                 d_in,         # input embedding dimension
                 d_hidden      # hidden layer dimension
                 ):
        super().__init__()
        self.noise_scheduler = DDPMScheduler(num_train_timesteps=n_timesteps,
                                             beta_schedule="sigmoid",
                                             beta_start=1e-7,
                                             beta_end=2e-3,
                                             clip_sample=False)
        self.time_embedding = nn.Embedding(n_timesteps, d_embedding)
        self.layers = nn.Sequential(
            nn.Linear(d_in * 2 + d_embedding, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_in),
        )

    def forward(self, x, noisy_y, timesteps):
        # convert timestep integers to embeddings
        timesteps = self.time_embedding(timesteps)
        # concat inputs
        x = torch.cat([x, noisy_y, timesteps], -1)
        # predict noise
        noise_prediction = self.layers(x)
        return noise_prediction

    def predict(self, x):
        # create starting noise
        noisy_y = torch.randn(x.shape, device=x.device)
        # iterate over diffusion steps, from t = T-1 down to 0
        for t in self.noise_scheduler.timesteps:
            # generate timestep longtensor for the whole batch
            timesteps = torch.zeros(x.shape[0], device=x.device).long() + t
            # predict noise
            with torch.no_grad():
                noise_prediction = self.forward(x, noisy_y, timesteps)
            # denoise to previous sample
            noisy_y = self.noise_scheduler.step(noise_prediction, t, noisy_y).prev_sample
        return noisy_y
```

## training code

```
import torch
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm
from diffusers.optimization import get_cosine_schedule_with_warmup

# AdamW optimizer
opt = optim.AdamW(model.parameters(), lr=lr)
# cosine LR schedule with warmup
scheduler = get_cosine_schedule_with_warmup(
    optimizer=opt,
    num_warmup_steps=1000,
    num_training_steps=(len(train_dataloader) * epochs),
)

# train loop
for epoch in range(epochs):
    for batch in tqdm(train_dataloader):
        x, y = batch  # both x, y are of shape (n, d)
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        bs = x.shape[0]
        # sample random noise
        noise = torch.randn(y.shape, device=y.device)
        # sample timesteps uniformly at random for each example
        timesteps = torch.randint(0, model.noise_scheduler.config.num_train_timesteps, (bs,),
                                  device=y.device).long()
        # add noise to y embedding according to the sampled timesteps
        noisy_y = model.noise_scheduler.add_noise(y, noise, timesteps)
        # predict noise
        noise_pred = model(x, noisy_y, timesteps)
        # MSE loss between noise prediction and the sampled noise
        loss = F.mse_loss(noise_pred, noise)
        # optimizer step
        opt.zero_grad()
        loss.backward()
        opt.step()
        scheduler.step()
```

As a validation metric, I'm looking at cosine similarity between the ground truth `y` embeddings and the outputs generated by `model.predict(x)`. After a lot of training, the diffusion outputs have a cosine similarity of ~0.77 to the ground truth, compared to ~0.92 for the baseline model. Training loss (MSE on noise prediction) is stuck at around 0.33.
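
For reference, the validation metric is computed roughly like this (a sketch; `x_val` and `y_val` stand in for a held-out batch):

```
import torch
import torch.nn.functional as F

# cosine similarity between diffusion reconstructions and ground-truth embeddings,
# averaged over a held-out batch (x_val, y_val are placeholder names)
with torch.no_grad():
    y_pred = model.predict(x_val)  # (n, d) reconstructed embeddings
    cos_sim = F.cosine_similarity(y_pred, y_val, dim=-1).mean()
```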

I've noticed that as the noise prediction MSE loss goes down, the cosine similarity of the reconstructions gets worse. This makes me think there might be a bug in the `predict` code, but I haven't found it yet.

I've also noticed that noise prediction loss correlates strongly with timestep. This is a plot of noise prediction MSE loss as a function of timestep. I'm not sure if this is typical for diffusion models or if it points to an issue.
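
The per-timestep numbers come from bucketing the unreduced noise MSE by the sampled timestep, roughly like this (a sketch; `loss_by_timestep` and the helper are hypothetical names):

```
from collections import defaultdict

import torch.nn.functional as F

loss_by_timestep = defaultdict(list)

def log_loss_by_timestep(noise_pred, noise, timesteps):
    # per-example noise MSE, bucketed by the sampled diffusion timestep;
    # called from inside the training loop after computing noise_pred
    per_example = F.mse_loss(noise_pred, noise, reduction="none").mean(dim=-1)  # (bs,)
    for t, l in zip(timesteps.tolist(), per_example.tolist()):
        loss_by_timestep[t].append(l)
```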

Reconstructions from `model.predict` also have way higher magnitude compared to the target vectors. I know the scheduler has a clip parameter, but that feels like a band-aid over a more fundamental issue.
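
The magnitude gap shows up directly when comparing mean L2 norms, e.g. with a quick check like this (again, `x_val`/`y_val` are a held-out batch):

```
import torch

with torch.no_grad():
    pred_norm = model.predict(x_val).norm(dim=-1).mean()  # mean L2 norm of reconstructions
    target_norm = y_val.norm(dim=-1).mean()               # mean L2 norm of ground-truth y
print(f"reconstruction norm: {pred_norm:.3f} vs target norm: {target_norm:.3f}")
```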

I'd appreciate anyone pointing out mistakes I'm making in setting up the problem or applying the diffusion scheduler.