I have been reading about UNets and Stable Diffusion and want to train one. I understand the original UNet architecture and how its channels, height and width evolve across the down blocks and up blocks.
Just to be sure, this is my understanding: for an up block with 2 ResNet blocks, only the first Conv2d of the first ResNet block has double the number of input channels, because it receives the output of the previous mid block/up block plus the skip connection. For the remaining ResNet blocks, in_channels is the same as out_channels of the previous ResNet block. Is this correct?
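To make that understanding concrete, here is a small plain-Python sketch of the input channels I would expect for each ResNet block in one up block. The function name and its arguments are my own (hypothetical); it models my assumption, not what diffusers actually does.

```python
def expected_up_block_in_channels(prev_channels, skip_channels, out_channels, num_resnets=2):
    """Input channels of each ResNet block in one up block, under my assumption
    that only the first ResNet block receives the concatenated skip connection."""
    in_channels = []
    for i in range(num_resnets):
        if i == 0:
            # previous mid/up block output concatenated with one skip connection
            in_channels.append(prev_channels + skip_channels)
        else:
            # later ResNet blocks only see the previous ResNet block's output
            in_channels.append(out_channels)
    return in_channels


# e.g. a 512-channel mid block output plus a 512-channel skip, 512 output channels
print(expected_up_block_in_channels(512, 512, 512))                 # [1024, 512]
print(expected_up_block_in_channels(512, 512, 512, num_resnets=3))  # [1024, 512, 512]
```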
Now, when I implement the following UNet2D model using the diffusers library, I am confused about how the ResNet blocks in the up blocks work.
```python
from diffusers import UNet2DModel

# We'll train on 64-pixel square images
image_size = 64

# Create a model
model = UNet2DModel(
    sample_size=image_size,  # the target image resolution
    in_channels=4,  # number of input channels: 4 for latents, 3 for pixel space
    out_channels=4,  # number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 256, 512),
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "AttnUpBlock2D",
        "UpBlock2D",  # a regular ResNet upsampling block
        "UpBlock2D",
    ),
)
```
My questions:
- Even with layers_per_block = 2, each down block has 2 ResNet blocks but each up block has 3. Why?
- Also, the first AttnUpBlock2D contains 3 ResNet blocks, and each ResNet block applies a Conv2d twice. It shows the following structure:
  - 1st ResNet block:
    - Conv2d(in_channels=1024, out_channels=512)
    - Conv2d(in_channels=512, out_channels=512)
  - 2nd ResNet block:
    - Conv2d(in_channels=1024, out_channels=512)
    - Conv2d(in_channels=512, out_channels=512)
  - 3rd ResNet block:
    - Conv2d(in_channels=768, out_channels=512)
    - Conv2d(in_channels=512, out_channels=512)
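Working backwards from the printed numbers: if I subtract the channels each ResNet block gets from the layer before it (512 from the mid block for the 1st block, then 512 from the previous ResNet block for the 2nd and 3rd), whatever is left over must be concatenated from somewhere else. A quick check of that arithmetic; the lists are copied from the structure above, and reading the leftover as "extra concatenated channels" is only my guess:

```python
# in_channels of the first Conv2d in each ResNet block of the first AttnUpBlock2D
observed_in = [1024, 1024, 768]

# channels arriving from the previous layer: the mid block (512) for the 1st block,
# then the previous ResNet block's output (512) for the 2nd and 3rd blocks
from_previous = [512, 512, 512]

# leftover channels that must come from some other concatenated tensor
extra = [obs - prev for obs, prev in zip(observed_in, from_previous)]
print(extra)  # [512, 512, 256]
```

So every ResNet block, not only the first, seems to receive an extra tensor, and the last one is 256 channels wide, which equals the third entry of block_out_channels.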
If my input latents have height = width = 64 and 4 channels, then the mid block passes an 8 × 8 × 512 tensor, and the skip connection from the last down block is also 8 × 8 × 512. So I understand why the 1st Conv2d in the 1st ResNet block has 1024 input channels. But how does the 2nd ResNet block have 1024 input channels, and the 3rd one 768?
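The 8 × 8 figure comes from the down path. A minimal sketch of that arithmetic, assuming (as I understand diffusers' UNet2DModel) that every down block except the last ends with a downsampler that halves height and width, and that each down block outputs the channel count given by its entry in block_out_channels:

```python
block_out_channels = (64, 128, 256, 512)
size = 64  # latent height/width

shapes = []  # (channels, height, width) at the output of each down block
for i, channels in enumerate(block_out_channels):
    is_final_block = i == len(block_out_channels) - 1
    if not is_final_block:
        size //= 2  # assumed downsampler at the end of every non-final block
    shapes.append((channels, size, size))

print(shapes)
# [(64, 32, 32), (128, 16, 16), (256, 8, 8), (512, 8, 8)]
```

The last entry is what the mid block receives, and the last two down blocks both sit at 8 × 8, so tensors from either of them could be concatenated with the 8 × 8 × 512 mid block output.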
I am sorry if the question sounds too basic, but I would like to understand how this works. I tried drawing the entire architecture in case it helps clear up my confusion. In the diagram, inside the first up block, you can see that the last Conv2d of the 1st ResNet block has 512 output channels, but the in_channels of the next ResNet block is 1024.