Hi everyone,
I’m building a custom Stable Diffusion pipeline to generate 1D signal data instead of images. Here’s what I’ve implemented so far:
Trained a custom autoencoder to encode raw 1D signals into a latent space and decode them back. The latent dimension is 4.
Trained a 1D UNet that takes as input the latent signal, text embedding, and timestep embedding, and predicts noise.
Built a custom pipeline using `DDIMScheduler` to iteratively denoise and decode signals from latent space.
However, the final generated signal is just noise and doesn’t resemble the training data. I’m not sure whether:
- My latent space isn’t properly learned,
- The UNet training is unstable,
- Or my sampling loop has a mistake.
Has anyone encountered a similar issue when adapting Stable Diffusion to 1D data?
Any help, debugging tips, or pointers would be greatly appreciated!
Thanks in advance
Here’s my sampling pipeline (`__call__`):
@torch.no_grad()
def __call__(
    self,
    prompt,
    signal_length=None,
    latent_length=None,
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=None,
    latents=None,
    return_dict=True,
    output_type="np",
    sigmas: List[float] = None,
    timesteps: List[int] = None,
):
    """Generate signals from a text prompt by iterative latent denoising.

    Args:
        prompt: A single string or a sequence of strings; the batch size is
            1 for a string, otherwise ``len(prompt)``.
        signal_length: Accepted for interface compatibility. It is not read
            by the sampling loop or the decoder call below.
        latent_length: Size of the last dimension of the initial noise
            tensor. When both this and ``signal_length`` are ``None`` it is
            read from ``self.vae.config`` (see the review note in step 1).
        num_inference_steps: Number of denoising steps requested from the
            scheduler.
        guidance_scale: Classifier-free guidance weight; guidance is applied
            only when this is greater than 1.0.
        generator: Forwarded to ``randn_tensor`` when noise is sampled here.
        latents: Optional starting latents. They are moved to the execution
            device and are never modified in place.
        return_dict: If true, return ``{"signals": signal}``; otherwise
            return ``signal`` directly.
        output_type: ``"np"`` converts the decoded signal to a NumPy array;
            any other value returns whatever ``self.vae.decode`` produced.
        sigmas: Optional custom sigmas forwarded to ``retrieve_timesteps``.
        timesteps: Optional custom timesteps forwarded to
            ``retrieve_timesteps``.

    Returns:
        ``{"signals": signal}`` or ``signal``, depending on ``return_dict``.

    Raises:
        ValueError: If noise has to be sampled but ``latent_length`` is not
            known (``signal_length`` was given without ``latent_length``).
    """
    device = self._execution_device
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    use_guidance = guidance_scale > 1.0
    self._interrupt = False

    # 1. Determine latent length. An explicit `latent_length` always wins;
    # the config fallback is used only when the caller gave no size at all.
    if latent_length is None and signal_length is None:
        # NOTE(review): the fallback reads the "latent_channels" key, which
        # by its name looks like a channel count rather than a temporal
        # length -- confirm it matches the length the UNet was trained on.
        latent_length = getattr(self.vae.config, "latent_channels", 64)

    # 2. Encode prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt, device, guidance_scale=guidance_scale
    )
    if use_guidance:
        # Unconditional half first, so that `.chunk(2)` in the loop yields
        # (unconditional, conditional) in that order.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 3. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )

    # 4. Prepare latents
    if latents is None:
        if latent_length is None:
            raise ValueError(
                "latent_length must be provided when signal_length is given "
                "and no latents are passed."
            )
        latents = randn_tensor(
            (batch_size, self.unet.config["in_channels"], latent_length),
            generator=generator,
            device=device,
            dtype=prompt_embeds.dtype,
        )
    else:
        latents = latents.to(device)
    # Out-of-place multiply: `.to(device)` can return the caller's own tensor,
    # and an in-place `*=` would silently rescale it.
    latents = latents * self.scheduler.init_noise_sigma

    # 5. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            latent_model_input = torch.cat([latents] * 2) if use_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # One timestep entry per row actually fed to the UNet; with
            # guidance the batch is doubled, so `batch_size` alone is too small.
            t_batch = torch.full(
                (latent_model_input.shape[0],),
                int(t),
                dtype=torch.long,
                device=latent_model_input.device,
            )

            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t_batch,
                text_embedding=prompt_embeds,
            ).sample

            if use_guidance:
                noise_uncond, noise_text = noise_pred.chunk(2)
                noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)

            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()

    # 6. Decode
    # NOTE(review): latents are decoded as-is. If the UNet was trained on
    # scaled/normalised VAE latents, the inverse scaling must be applied here
    # before decoding -- confirm against the training code.
    self.vae.eval()
    signal = self.vae.decode(latents)[0]
    if output_type == "np":
        signal = signal.squeeze(1).detach().cpu().numpy()
    return {"signals": signal} if return_dict else signal