I profiled the pipeline line by line and found that the slowdown comes from the UNet's `forward` method. The results are below (run at 256x256 instead of 512x512):
```
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   244                                           @profile
   245                                           def forward(
   246                                               self,
   247                                               sample: torch.FloatTensor,
   248                                               timestep: Union[torch.Tensor, float, int],
   249                                               class_labels: Optional[torch.Tensor] = None,
   250                                               return_dict: bool = True,
   251                                           ) -> Union[UNet2DOutput, Tuple]:
   252                                               r"""
   253                                               The [`UNet2DModel`] forward method.
   255                                               Args:
   256                                                   sample (`torch.FloatTensor`):
   257                                                       The noisy input tensor with the following shape `(batch, channel, height, width)`.
   258                                                   timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
   259                                                   class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
   260                                                       Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
   261                                                   return_dict (`bool`, *optional*, defaults to `True`):
   262                                                       Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
   264                                               Returns:
   265                                                   [`~models.unet_2d.UNet2DOutput`] or `tuple`:
   266                                                       If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
   267                                                       returned where the first element is the sample tensor.
   268                                               """
   269                                               # 0. center input if necessary
   270        25        111.8      4.5      0.0      if self.config.center_input_sample:
   271                                                   sample = 2 * sample - 1.0
   273                                               # 1. time
   274        25          7.3      0.3      0.0      timesteps = timestep
   275        25        139.4      5.6      0.0      if not torch.is_tensor(timesteps):
   276                                                   timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
   277        25         93.4      3.7      0.0      elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
   278        25        616.4     24.7      0.0          timesteps = timesteps[None]
   280                                               # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
   281        25      13295.6    531.8      0.1      timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
   283        25      57290.1   2291.6      0.4      t_emb = self.time_proj(timesteps)
   285                                               # timesteps does not contain any weights and will always return f32 tensors
   286                                               # but time_embedding might actually be running in fp16. so we need to cast here.
   287                                               # there might be better ways to encapsulate this.
   288        25      39865.4   1594.6      0.3      t_emb = t_emb.to(dtype=self.dtype)
   289        25     114592.1   4583.7      0.9      emb = self.time_embedding(t_emb)
   291        25        169.0      6.8      0.0      if self.class_embedding is not None:
   292        25          8.2      0.3      0.0          if class_labels is None:
   293                                                       raise ValueError("class_labels should be provided when doing class conditioning")
   295        25         78.7      3.1      0.0          if self.config.class_embed_type == "timestep":
   296                                                       class_labels = self.time_proj(class_labels)
   298        25     318646.7  12745.9      2.5          class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
   299        25      10447.8    417.9      0.1          emb = emb + class_emb
   300                                               elif self.class_embedding is None and class_labels is not None:
   301                                                   raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
   303                                               # 2. pre-process
   304        25          7.7      0.3      0.0      skip_sample = sample
   305        25    1352664.2  54106.6     10.5      sample = self.conv_in(sample)
   307                                               # 3. down
   308        25         38.2      1.5      0.0      down_block_res_samples = (sample,)
   309       125        906.6      7.3      0.0      for downsample_block in self.down_blocks:
   310       100       1104.8     11.0      0.0          if hasattr(downsample_block, "skip_conv"):
   311                                                       sample, res_samples, skip_sample = downsample_block(
   312                                                           hidden_states=sample, temb=emb, skip_sample=skip_sample
   313                                                       )
   314                                                   else:
   315       100    5186059.4  51860.6     40.3              sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
   317       100        121.6      1.2      0.0          down_block_res_samples += res_samples
   319                                               # 4. mid
   320        25     441380.1  17655.2      3.4      sample = self.mid_block(sample, emb)
   322                                               # 5. up
   323        25         15.7      0.6      0.0      skip_sample = None
   324       125        577.8      4.6      0.0      for upsample_block in self.up_blocks:
   325       100       1319.4     13.2      0.0          res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
   326       100        222.8      2.2      0.0          down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
   328       100        560.0      5.6      0.0          if hasattr(upsample_block, "skip_conv"):
   329                                                       sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
   330                                                   else:
   331       100    3741789.3  37417.9     29.1              sample = upsample_block(sample, res_samples, emb)
   333                                               # 6. post-process
   334        25      21149.3    846.0      0.2      sample = self.conv_norm_out(sample)
   335        25      72442.4   2897.7      0.6      sample = self.conv_act(sample)
   336        25    1477399.9  59096.0     11.5      sample = self.conv_out(sample)
   338        25         45.9      1.8      0.0      if skip_sample is not None:
   339                                                   sample += skip_sample
   341        25        291.7     11.7      0.0      if self.config.time_embedding_type == "fourier":
   342                                                   timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
   343                                                   sample = sample / timesteps
   345        25         16.6      0.7      0.0      if not return_dict:
   346                                                   return (sample,)
   348        25       1563.5     62.5      0.0      return UNet2DOutput(sample=sample)
```
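One caveat when reading these numbers, assuming the model runs on a GPU: CUDA kernels launch asynchronously, so a line profiler charges GPU time to whichever Python line next makes the host wait, which is not necessarily the line that launched the work. Setting the environment variable `CUDA_LAUNCH_BLOCKING=1` while profiling makes launches synchronous, which should make per-line attribution more faithful at the cost of overall speed. For whole-call timings, here is a minimal sketch that synchronizes around the measurement; the `timed` helper and the 6-to-64-channel test conv are hypothetical, not part of diffusers:

```python
import time

import torch


def timed(fn, *args, warmup=2, iters=10, **kwargs):
    """Average wall-clock seconds per call of fn(*args, **kwargs).

    CUDA kernels run asynchronously, so we synchronize before starting and
    before stopping the clock; otherwise queued GPU work would be charged
    to whatever code happens to block later.
    """
    use_cuda = torch.cuda.is_available()
    with torch.no_grad():
        for _ in range(warmup):  # absorb one-off costs such as lazy CUDA/cuDNN initialization
            fn(*args, **kwargs)
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn(*args, **kwargs)
        if use_cuda:
            torch.cuda.synchronize()  # wait for all queued kernels before reading the clock
        elapsed = time.perf_counter() - start
    return elapsed / iters


# Hypothetical stand-in for a UNet layer: a 6-channel 3x3 conv at 256x256.
device = "cuda" if torch.cuda.is_available() else "cpu"
conv = torch.nn.Conv2d(6, 64, kernel_size=3, padding=1).to(device)
x = torch.randn(1, 6, 256, 256, device=device)
print(f"{timed(conv, x) * 1e3:.2f} ms per call")
```

Wrapping the full `unet(...)` call, or a single entry of `unet.down_blocks`, the same way gives timings that do not depend on where the next synchronization happens to fall.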
Most of the time goes to lines 315 and 331, where the downsampling and upsampling blocks run (40.3% and 29.1% of the total). Again, I did not change anything in the default UNet2DModel except setting the input channels to 6 instead of 3. Is there something I was supposed to set when training a custom model from scratch (perhaps related to the skip connections, given the `skip_conv` branches above)?
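Lines 315 and 331 are where nearly all of the network's convolutions and attention layers run, so some dominance is expected. What may make them disproportionately slow at 512x512 is self-attention at high resolution. Assuming the default `down_block_types` put self-attention in every block after the first (worth verifying in the model's `config.json`), block *i* runs at `image_size / 2**i`, so the first attention block sees a 128x128 map for a 256x256 input and a 256x256 map for a 512x512 input. The back-of-the-envelope arithmetic below compares how a 3x3 conv and a self-attention score matrix scale; the 224-channel width is an assumption and cancels out of every ratio:

```python
def conv3x3_macs(side, c_in, c_out):
    """Multiply-accumulates of a stride-1, 'same'-padded 3x3 conv on a side x side map."""
    return side * side * c_in * c_out * 3 * 3


def attention_score_entries(side):
    """Entries in one self-attention score matrix over a side x side map (tokens x tokens)."""
    tokens = side * side
    return tokens * tokens


WIDTH = 224  # assumed first-level channel width; it cancels in every ratio below

# Doubling the image side: convolutions do 4x the work ...
conv_ratio = conv3x3_macs(512, WIDTH, WIDTH) / conv3x3_macs(256, WIDTH, WIDTH)
# ... but attention at the first attention level (side = image_size // 2) does 16x.
attn_ratio = attention_score_entries(512 // 2) / attention_score_entries(256 // 2)
# Going from 3 to 6 input channels only doubles conv_in, which is a single layer.
conv_in_ratio = conv3x3_macs(256, 6, WIDTH) / conv3x3_macs(256, 3, WIDTH)

for size in (256, 512):
    side = size // 2
    print(f"{size}x{size} input: first attention level {side}x{side}, "
          f"{attention_score_entries(side):,} score entries per head")
print(conv_ratio, attn_ratio, conv_in_ratio)  # 4.0 16.0 2.0
```

If attention turns out to be the bottleneck, a config with fewer high-resolution attention blocks (for example, swapping some `AttnDownBlock2D`/`AttnUpBlock2D` entries for `DownBlock2D`/`UpBlock2D`) would be one thing to try; the 6-channel input by itself should barely matter.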
Additionally, I found that line 278 took a lot of time when `.to()` was used, so I now move the timesteps to the correct device once, outside the sampling loop. This raised throughput from 2 it/s to ~6 it/s with a 512x512 input, but a full run still takes over two minutes.
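The device fix described above might look like the following minimal sketch, assuming a 25-step schedule (matching the 25 hits per line in the profile). The `timesteps` tensor here is a hypothetical stand-in for the scheduler's output, and `unet`, `y_t` in the comment refer to the pipeline's own objects:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical 25-step schedule, moved to the device once, before sampling.
timesteps = torch.linspace(999, 0, steps=25).long().to(device)

for t in timesteps:
    # Iterating a 1-D tensor yields 0-d tensors that already live on `device`,
    # so UNet2DModel.forward takes its `timesteps[None]` branch with no
    # host-to-device copy inside the loop.
    assert t.device == timesteps.device and t.dim() == 0
    ...  # e.g. noise_pred = unet(y_t, t, class_labels=...).sample
```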
Lastly, to clarify, this is the line I added at the end of each timestep in the sampling loop to match the masking in the Palette implementation:
```python
# mask image: take y_t where mask_image == 1, original_image where it is 0
y_t = original_image * (1.0 - mask_image) + mask_image * y_t
```
I also profiled the masking line, and it did not cause any slowdown; the time is still spent on the UNet lines mentioned above.
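For reference, the masking step is only a few elementwise tensor ops, which is consistent with it not showing up as a hotspot. A minimal sketch of what it computes, assuming `mask_image` is 1 inside the region being inpainted and 0 where the original should be kept (the shapes and the hole location are made up for illustration):

```python
import torch


def apply_known_region(original_image, y_t, mask_image):
    """Keep original pixels where mask == 0 and the current sample where mask == 1."""
    return original_image * (1.0 - mask_image) + mask_image * y_t


original_image = torch.zeros(1, 3, 8, 8)
y_t = torch.ones(1, 3, 8, 8)
mask_image = torch.zeros(1, 1, 8, 8)  # single channel, broadcast over the 3 image channels
mask_image[..., 2:6, 2:6] = 1.0       # hypothetical 4x4 hole to inpaint

blended = apply_known_region(original_image, y_t, mask_image)
```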
Does anyone have insight into why those lines would cause such a slowdown?