Custom pipeline inference speed extremely slow

Hi, I have been using Diffusers recently and was able to create a custom inpainting model. The model is based on the Palette repo (GitHub: Janspiry/Palette-Image-to-Image-Diffusion-Models, an unofficial PyTorch implementation of "Palette: Image-to-Image Diffusion Models"). My model takes an image, a mask, and a label as input and inpaints the masked region with that label.

In the Palette implementation, the ground-truth image is concatenated with a conditional image whose masked regions (given by the mask) are filled with noise, which makes the UNet input 6 channels instead of the usual 3. However, their model does not have any class conditioning, while the UNet2DModel in Diffusers does, so I went ahead and created a "PalettePipeline" in Diffusers. All I really had to change was (1) concatenating the two model inputs into one so the UNet2DModel takes 6 input channels instead of 3, and (2) adding a step at the end of each timestep in the base DiffusionPipeline that re-applies the mask to the image. The only change I made to the UNet2DModel was the number of input channels; everything else is kept at the default parameters.
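
To make the description concrete, here is a minimal sketch of what I mean (the label count, tensor shapes, and sample size are placeholders, not my exact config):

import torch
from diffusers import UNet2DModel

# 6-channel UNet2DModel with class conditioning; everything else left at the defaults.
unet = UNet2DModel(
    sample_size=64,          # small size just for this sketch
    in_channels=6,           # 3 channels for y_t + 3 for the masked/noised conditional image
    out_channels=3,
    num_class_embeds=10,     # placeholder number of inpainting labels
)

# One denoising step inside the custom pipeline: the conditional image is
# concatenated with the current noisy sample along the channel dimension.
y_t = torch.randn(1, 3, 64, 64)
cond_image = torch.randn(1, 3, 64, 64)
labels = torch.tensor([3])

model_input = torch.cat([cond_image, y_t], dim=1)                        # (B, 6, H, W)
noise_pred = unet(model_input, timestep=50, class_labels=labels).sample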

After training this model, I was able to achieve some pretty good results. However, I noticed that inference with this model is very slow. These are my experiments:

  1. Sampling a 512x512 image with the custom model (half precision, DDIM with 100 timesteps, on an NVIDIA RTX A6000) takes about 3 minutes. The tqdm progress bar showed a sampling speed of about 2 iters/sec.

  2. As a sanity check, I ran an off-the-shelf Stable Diffusion inpainting model (runwayml/stable-diffusion-inpainting on the Hugging Face Hub) with the same parameters and inputs; a rough version of this check is sketched after this list. That model only took about 3 seconds and sampled at about 26 iters/sec.

  3. Finally, I ran inference directly in the Palette repo, but without any class conditioning. With the same parameters, inference only took about 2 seconds and sampled at 34 iters/sec.
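
The sanity check in (2) was roughly along these lines (the prompt and images here are placeholders just so the sketch runs; in my test I reused the same image and mask as in the custom pipeline):

import time
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline, DDIMScheduler

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# Placeholder inputs (black image, fully white mask).
init_image = Image.new("RGB", (512, 512))
mask = Image.new("L", (512, 512), 255)

start = time.time()
pipe(prompt="a cat", image=init_image, mask_image=mask, num_inference_steps=100)
print(f"inference took {time.time() - start:.1f}s")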

I have tried looking into why my custom pipeline is so slow and tried a few things, but none of them have helped. For example, I tried enable_attention_slicing() and CPU offload, but neither changed the inference speed. I am also 100% sure that the model is running on the GPU, since I checked the devices of all the inputs and watched utilization through nvidia-smi. To my knowledge, the only differences are the number of input channels in the UNet2DModel and the masking step at the end of each timestep. That masking step comes straight from the Palette repo and, based on result (3) above, it does not cause any drastic slowdown.
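
For completeness, the kind of device/dtype check I mean is just this (illustrative names, not my exact code; the point is that both the weights and the inputs should come back as fp16 tensors on the GPU):

import torch
from diffusers import UNet2DModel

unet = UNet2DModel(in_channels=6, out_channels=3, num_class_embeds=10).to("cuda", dtype=torch.float16)
sample = torch.randn(1, 6, 512, 512, device="cuda", dtype=torch.float16)

print(next(unet.parameters()).device, next(unet.parameters()).dtype)  # expect: cuda:0 torch.float16
print(sample.device, sample.dtype)                                    # expect: cuda:0 torch.float16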

If anyone has any solutions to try and fix the super slow inference for my custom pipeline they would be much appreciated, thank you!

I was able to do some profiling and found that the bottleneck is in the UNet's forward function. The profiling results are listed below (run at 256x256 instead of 512x512):

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   244                                               @profile
   245                                               def forward(
   246                                                   self,
   247                                                   sample: torch.FloatTensor,
   248                                                   timestep: Union[torch.Tensor, float, int],
   249                                                   class_labels: Optional[torch.Tensor] = None,
   250                                                   return_dict: bool = True,
   251                                               ) -> Union[UNet2DOutput, Tuple]:
   252                                                   r"""
   253                                                   The [`UNet2DModel`] forward method.
   254                                           
   255                                                   Args:
   256                                                       sample (`torch.FloatTensor`):
   257                                                           The noisy input tensor with the following shape `(batch, channel, height, width)`.
   258                                                       timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
   259                                                       class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
   260                                                           Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
   261                                                       return_dict (`bool`, *optional*, defaults to `True`):
   262                                                           Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
   263                                           
   264                                                   Returns:
   265                                                       [`~models.unet_2d.UNet2DOutput`] or `tuple`:
   266                                                           If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
   267                                                           returned where the first element is the sample tensor.
   268                                                   """
   269                                                   # 0. center input if necessary
   270        25        111.8      4.5      0.0          if self.config.center_input_sample:
   271                                                       sample = 2 * sample - 1.0
   272                                           
   273                                                   # 1. time
   274        25          7.3      0.3      0.0          timesteps = timestep
   275        25        139.4      5.6      0.0          if not torch.is_tensor(timesteps):
   276                                                       timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
   277        25         93.4      3.7      0.0          elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
   278        25        616.4     24.7      0.0              timesteps = timesteps[None]
   279                                           
   280                                                   # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
   281        25      13295.6    531.8      0.1          timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
   282                                           
   283        25      57290.1   2291.6      0.4          t_emb = self.time_proj(timesteps)
   284                                           
   285                                                   # timesteps does not contain any weights and will always return f32 tensors
   286                                                   # but time_embedding might actually be running in fp16. so we need to cast here.
   287                                                   # there might be better ways to encapsulate this.
   288        25      39865.4   1594.6      0.3          t_emb = t_emb.to(dtype=self.dtype)
   289        25     114592.1   4583.7      0.9          emb = self.time_embedding(t_emb)
   290                                           
   291        25        169.0      6.8      0.0          if self.class_embedding is not None:
   292        25          8.2      0.3      0.0              if class_labels is None:
   293                                                           raise ValueError("class_labels should be provided when doing class conditioning")
   294                                           
   295        25         78.7      3.1      0.0              if self.config.class_embed_type == "timestep":
   296                                                           class_labels = self.time_proj(class_labels)
   297                                           
   298        25     318646.7  12745.9      2.5              class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
   299        25      10447.8    417.9      0.1              emb = emb + class_emb
   300                                                   elif self.class_embedding is None and class_labels is not None:
   301                                                       raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
   302                                           
   303                                                   # 2. pre-process
   304        25          7.7      0.3      0.0          skip_sample = sample
   305        25    1352664.2  54106.6     10.5          sample = self.conv_in(sample)
   306                                           
   307                                                   # 3. down
   308        25         38.2      1.5      0.0          down_block_res_samples = (sample,)
   309       125        906.6      7.3      0.0          for downsample_block in self.down_blocks:
   310       100       1104.8     11.0      0.0              if hasattr(downsample_block, "skip_conv"):
   311                                                           sample, res_samples, skip_sample = downsample_block(
   312                                                               hidden_states=sample, temb=emb, skip_sample=skip_sample
   313                                                           )
   314                                                       else:
   315       100    5186059.4  51860.6     40.3                  sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
   316                                           
   317       100        121.6      1.2      0.0              down_block_res_samples += res_samples
   318                                           
   319                                                   # 4. mid
   320        25     441380.1  17655.2      3.4          sample = self.mid_block(sample, emb)
   321                                           
   322                                                   # 5. up
   323        25         15.7      0.6      0.0          skip_sample = None
   324       125        577.8      4.6      0.0          for upsample_block in self.up_blocks:
   325       100       1319.4     13.2      0.0              res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
   326       100        222.8      2.2      0.0              down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
   327                                           
   328       100        560.0      5.6      0.0              if hasattr(upsample_block, "skip_conv"):
   329                                                           sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
   330                                                       else:
   331       100    3741789.3  37417.9     29.1                  sample = upsample_block(sample, res_samples, emb)
   332                                           
   333                                                   # 6. post-process
   334        25      21149.3    846.0      0.2          sample = self.conv_norm_out(sample)
   335        25      72442.4   2897.7      0.6          sample = self.conv_act(sample)
   336        25    1477399.9  59096.0     11.5          sample = self.conv_out(sample)
   337                                           
   338        25         45.9      1.8      0.0          if skip_sample is not None:
   339                                                       sample += skip_sample
   340                                           
   341        25        291.7     11.7      0.0          if self.config.time_embedding_type == "fourier":
   342                                                       timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
   343                                                       sample = sample / timesteps
   344                                           
   345        25         16.6      0.7      0.0          if not return_dict:
   346                                                       return (sample,)
   347                                           
   348        25       1563.5     62.5      0.0          return UNet2DOutput(sample=sample)

This shows that most of the time is spent on lines 315 and 331, where the downsampling and upsampling blocks run. Again, I did not change anything in the default UNet2DModel aside from using 6 input channels instead of 3. Is there something I was supposed to set when training a custom model from scratch (maybe something related to the skip connections, judging by the trace above)?
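
For reference, the trace above is line_profiler output. A rough, self-contained way to reproduce that kind of trace (the UNet config and step count here are stand-ins, not my exact script) is:

import torch
from line_profiler import LineProfiler
from diffusers import UNet2DModel

device = "cuda" if torch.cuda.is_available() else "cpu"
unet = UNet2DModel(in_channels=6, out_channels=3, num_class_embeds=10).to(device).eval()

lp = LineProfiler()
unet.forward = lp(unet.forward)   # wrap the forward method so each line gets timed

with torch.no_grad():
    sample = torch.randn(1, 6, 256, 256, device=device)
    labels = torch.tensor([3], device=device)
    for t in range(25):           # stand-in for the 25 sampling steps shown in the trace
        unet(sample, t, class_labels=labels)

lp.print_stats()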

Additionally, I found that line 278 ate up a lot of time when .to() was used on the timestep, so I now move the timesteps to the correct device outside of the sampling loop. This sped up sampling from 2 iters/sec to ~6 iters/sec with a 512x512 input, but it still takes over two minutes to complete.
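
Concretely, the change was along these lines (a sketch with the scheduler and loop body simplified):

import torch
from diffusers import DDIMScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"
scheduler = DDIMScheduler()
scheduler.set_timesteps(100)

# Before: each iteration called something like t.to(device) inside the loop.
# After: move the whole timestep tensor to the device once, before the loop.
timesteps = scheduler.timesteps.to(device)
for t in timesteps:
    pass  # one denoising step (UNet forward + scheduler.step + masking) goes here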

Lastly, for further clarification, this is the line I added at the end of each timestep of the sampling loop to match the masking in the Palette implementation:

# mask image: keep the known pixels from the original image; only the masked region comes from the model
y_t = original_image * (1. - mask_image) + mask_image * y_t

Using the profiler, I checked whether this line caused any slowdown and it did not; the slowdown still comes from the down/up blocks mentioned above.
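
To make the placement clear, the end of the sampling loop looks roughly like this (the tensors and scheduler here are stand-ins for the pipeline's actual inputs):

import torch
from diffusers import DDIMScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"
scheduler = DDIMScheduler()
scheduler.set_timesteps(100)

# Stand-in tensors; in the real pipeline these come from the user's image, mask, and noise.
original_image = torch.randn(1, 3, 512, 512, device=device)
mask_image = (torch.rand(1, 1, 512, 512, device=device) > 0.5).float()
y_t = torch.randn(1, 3, 512, 512, device=device)

for t in scheduler.timesteps.to(device):
    noise_pred = torch.randn_like(y_t)                    # placeholder for the 6-channel UNet call
    y_t = scheduler.step(noise_pred, t, y_t).prev_sample  # one DDIM update
    # keep the known pixels; only the masked region is taken from the model
    y_t = original_image * (1. - mask_image) + mask_image * y_t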

Would anyone have some insight into why the lines listed above would be causing such a slowdown?