Torch firing wrong forward_hook

This is a cross-post. However, the root problem may be more of an HF issue than a torch one.

Problem

I am trying to modify the HF UNet for diffusion models by injecting extra conditioning after the down and up blocks. Below is a minimal example of the problem. It seems the hook for the last down_block is firing before the one for the first block, and nothing in the UNet model's source code suggests that the blocks should run in reverse order.

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers import UNet2DConditionModel

# config
SD_MODEL = "runwayml/stable-diffusion-v1-5"
DIM = 15

unet = UNet2DConditionModel.from_pretrained(SD_MODEL, subfolder="unet")

# dummy inputs matching SD v1.5 (4x64x64 latents, 77x768 CLIP text embeddings)
bs = 2
timestep = torch.randint(0, 100, (bs,))
noise = torch.randn((bs, 4, 64, 64))
text_encoding = torch.randn((bs, 77, 768))
condition = torch.randn((bs, DIM))

DownOutput = tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]

# residual block that mixes a projected condition vector into a feature map
class ConditionResnet(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.call_count = 0
        self.projector = nn.Linear(in_dim, out_dim)
        self.conv1 = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1)
        self.non_linearity = F.silu
                        
    def forward(self, out: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        self.call_count += 1
        input_vector = out
        out = self.conv1(out) + self.projector(condition)[:, :, None, None]
        return input_vector + self.non_linearity(out)

# down blocks return tuples, so need slightly modified version  
class ConditionResnetDown(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.condition_resnet = ConditionResnet(in_dim, out_dim)
        
    def forward(self, x: DownOutput, condition: torch.Tensor) -> DownOutput:
        return self.condition_resnet(x[0], condition), x[1]

# wraps the HF UNet and applies one condition resnet per down/up block via forward hooks
class UNetWithConditions(nn.Module):
    def __init__(self, unet: nn.Module, col_channels: int, down_block_sizes: list[int], up_block_sizes: list[int]):
        super().__init__()
        self.unet = unet
        self.down_block_condition_resnets = nn.ModuleList(
            [ConditionResnetDown(col_channels, out_channel) for out_channel in down_block_sizes]
        )
        self.up_block_condition_resnets = nn.ModuleList(
            [ConditionResnet(col_channels, out_channel) for out_channel in up_block_sizes]
        )
        
        self.condition = None
        
        # forward hooks
        for i in range(len(self.unet.down_blocks)):
            self.unet.down_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.down_block_condition_resnets[i](outputs, self.condition))
        for i in range(len(self.unet.up_blocks)):
            self.unet.up_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.up_block_condition_resnets[i](outputs, self.condition))
        
    def forward(self, noise, timestep, text_encoding, condition):
        self.condition = condition
        out = self.unet(noise, timestep, text_encoding).sample
        self.condition = None
        return out

unet_with_conditions = UNetWithConditions(unet, DIM, [320, 640, 1280, 1280], [1280, 1280, 640, 320])
out2 = unet_with_conditions(noise, timestep, text_encoding, condition)
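
For context: the injection above relies on torch's behavior that a forward hook returning a non-None value replaces the module's output. A minimal standalone sketch of that, independent of the repro:

import torch
import torch.nn as nn

layer = nn.Identity()
layer.register_forward_hook(lambda module, inputs, output: output + 1)
print(layer(torch.zeros(1)))  # tensor([1.]) -- the hook's return value replaced the output

So each hook registered above should swap a block's output for the conditioned version.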

The reason I know that the resnet for the last down_block is the one being fired is that I inspect the call_count of each ConditionResnet via

(
    [a.condition_resnet.call_count for a in unet_with_conditions.down_block_condition_resnets],
    [a.call_count for a in unet_with_conditions.up_block_condition_resnets],
)

and I get ([0, 0, 0, 1], [0, 0, 0, 0]): only the resnet meant for the last down block ever ran.

Potential Causes

  • Is it possible that, if the model was compiled (maybe with JIT), the wrong hook is being fired? If I modify the model code above and re-execute it, the crash still seems to come from the old code for some reason. (A second possibility is sketched in the next bullet.)
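
  • Another thing worth ruling out (my best guess, not verified against the torch internals): Python closures capture loop variables by reference, not by value, so every lambda registered in the loops above may be reading the final value of i rather than the value it had at registration time. A minimal sketch of that behavior, with no torch involved:

hooks = [lambda: i for i in range(4)]
print([h() for h in hooks])   # [3, 3, 3, 3] -- every closure reads the final i

hooks = [lambda i=i: i for i in range(4)]  # a default argument freezes the current value
print([h() for h in hooks])   # [0, 1, 2, 3]

If that is what is happening here, the crash below is consistent with it: the weight of size [1280, 1280, 3, 3] belongs to the resnet built for the last down block, while the 320-channel input matches the first down block's output.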

Error Log:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~tmp/ipykernel_574/3305635741.py in <cell line: 2>()
      1 unet_with_conditions = UNetWithConditions(unet, DIM, [320, 640, 1280, 1280], [1280, 1280, 640, 320])
----> 2 out2 = unet_with_conditions(noise, timestep, text_encoding, condition)

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~tmp/ipykernel_574/2058376754.py in forward(self, noise, timestep, text_encoding, condition)
     59     def forward(self, noise, timestep, text_encoding, condition):
     60         self.condition = condition
---> 61         out = self.unet(noise, timestep, text_encoding).sample
     62         self.condition = None
     63         return out

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~nix/store/vzqny68wq33dcg4hkdala51n5vqhpnwc-python3-3.9.12/lib/python3.9/site-packages/diffusers/models/unet_2d_condition.py in forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, encoder_attention_mask, return_dict)
    795         for downsample_block in self.down_blocks:
    796             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
--> 797                 sample, res_samples = downsample_block(
    798                     hidden_states=sample,
    799                     temb=emb,

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1213         if _global_forward_hooks or self._forward_hooks:
   1214             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
-> 1215                 hook_result = hook(self, input, result)
   1216                 if hook_result is not None:
   1217                     result = hook_result

~tmp/ipykernel_574/2058376754.py in <lambda>(module, inputs, outputs)
     53         # forward hooks
     54         for i in range(len(self.unet.down_blocks)):
---> 55             self.unet.down_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.down_block_condition_resnets[i](outputs, self.condition))
     56         for i in range(len(self.unet.up_blocks)):
     57             self.unet.up_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.up_block_condition_resnets[i](outputs, self.condition))

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~tmp/ipykernel_574/2058376754.py in forward(self, x, condition)
     40 
     41     def forward(self, x: DownOutput, condition: torch.Tensor) -> DownOutput:
---> 42         return self.condition_resnet(x[0], condition), x[1]
     43 
     44 class UNetWithConditions(nn.Module):

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~tmp/ipykernel_574/2058376754.py in forward(self, out, condition)
     30         self.call_count += 1
     31         input_vector = out
---> 32         out = self.conv1(out) + self.projector(condition)[:, :, None, None]
     33         return input_vector + self.non_linearity(out)
     34 

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/conv.py in forward(self, input)
    461 
    462     def forward(self, input: Tensor) -> Tensor:
--> 463         return self._conv_forward(input, self.weight, self.bias)
    464 
    465 class Conv3d(_ConvNd):

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    457                             weight, bias, self.stride,
    458                             _pair(0), self.dilation, self.groups)
--> 459         return F.conv2d(input, weight, bias, self.stride,
    460                         self.padding, self.dilation, self.groups)
    461 

RuntimeError: Given groups=1, weight of size [1280, 1280, 3, 3], expected input[2, 320, 32, 32] to have 1280 channels, but got 320 channels instead