I’m trying to implement GLIGEN, which inserts a gated self-attention layer between the self-attention and cross-attention layers of each transformer block. However, I can’t figure out how to add this module to the UNet, which I load like this:
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
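For context, the layer I want to insert looks roughly like this. It is only a minimal sketch of GLIGEN-style gated self-attention in plain PyTorch; the class name, the argument names, and the use of nn.MultiheadAttention are my own placeholders, not the paper's code:

import torch
import torch.nn as nn

class GatedSelfAttention(nn.Module):
    # Placeholder sketch of GLIGEN-style gated self-attention.
    def __init__(self, dim, context_dim, n_heads):
        super().__init__()
        self.linear = nn.Linear(context_dim, dim)  # project grounding tokens to the visual dim
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.alpha = nn.Parameter(torch.zeros(1))  # learnable gate, starts closed (tanh(0) = 0)

    def forward(self, x, objs):
        # x:    (batch, n_visual, dim)        visual tokens from the transformer block
        # objs: (batch, n_objs, context_dim)  grounding tokens
        n_visual = x.shape[1]
        tokens = self.norm(torch.cat([x, self.linear(objs)], dim=1))
        attn_out, _ = self.attn(tokens, tokens, tokens)
        # residual over the visual tokens only, scaled by the learnable gate
        return x + torch.tanh(self.alpha) * attn_out[:, :n_visual]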
It seems hard to just iterate with for i, x in enumerate(unet.children()): in this case, because the model has the following deeply nested structure (a named_modules() traversal sketch follows the printout):
0 : Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1 : Timesteps()
2 : TimestepEmbedding(
  (linear_1): Linear(in_features=320, out_features=1280, bias=True)
  (act): SiLU()
  (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
)
3 : ModuleList(
  (0): CrossAttnDownBlock2D(
    (attentions): ModuleList(
      (0): Transformer2DModel(
        (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
        (proj_in): Linear(in_features=320, out_features=320, bias=True)
        (transformer_blocks): ModuleList(
          (0): BasicTransformerBlock(
            (attn1): CrossAttention(
              (to_q): Linear(in_features=320, out_features=320, bias=False)
              (to_k): Linear(in_features=320, out_features=320, bias=False)
              (to_v): Linear(in_features=320, out_features=320, bias=False)
              (to_out): ModuleList(
                (0): Linear(in_features=320, out_features=320, bias=True)
                (1): Dropout(p=0.0, inplace=False)
              )
            )
            (ff): FeedForward(
              (net): ModuleList(
                (0): GEGLU(
                  (proj): Linear(in_features=320, out_features=2560, bias=True)
                )
                (1): Dropout(p=0.0, inplace=False)
                (2): Linear(in_features=1280, out_features=320, bias=True)
              )
            )
            (attn2): CrossAttention(
              (to_q): Linear(in_features=320, out_features=320, bias=False)
              (to_k): Linear(in_features=1024, out_features=320, bias=False)
              (to_v): Linear(in_features=1024, out_features=320, bias=False)
              (to_out): ModuleList(
                (0): Linear(in_features=320, out_features=320, bias=True)
                (1): Dropout(p=0.0, inplace=False)
              )
            )
            (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
            (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
            (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          )
        )
        (proj_out): Linear(in_features=320, out_features=320, bias=True)
      )
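(The printout above is truncated.) Walking unet.named_modules() instead of unet.children() does reach the nested blocks. A rough sketch, assuming BasicTransformerBlock is importable from diffusers.models.attention in the installed version:

from diffusers.models.attention import BasicTransformerBlock

# named_modules() yields every submodule recursively, so the nesting doesn't matter.
blocks = [
    (name, module)
    for name, module in unet.named_modules()
    if isinstance(module, BasicTransformerBlock)
]
print(len(blocks), "transformer blocks found")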
Is there a way to add an additional attention block inside a pretrained BasicTransformerBlock?
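The only approach I could come up with is to register the new layer as a submodule of each block and wrap the block's forward. This is an untested sketch built on the GatedSelfAttention placeholder above; the function names are mine, it does not place the layer literally between attn1 and attn2 (that would require re-implementing the block's forward), and it doesn't yet solve how the grounding tokens get routed into each block:

def add_gated_self_attention(unet, context_dim, n_heads=8):
    for block in unet.modules():
        if not isinstance(block, BasicTransformerBlock):
            continue
        dim = block.attn1.to_q.in_features
        # Registering as a submodule so the new parameters show up in unet.parameters()
        # and are trained/saved with the rest of the model.
        block.fuser = GatedSelfAttention(dim, context_dim, n_heads)
        _wrap_forward(block)

def _wrap_forward(block):
    original_forward = block.forward  # bound method, captured once per block

    def forward(hidden_states, *args, grounding_tokens=None, **kwargs):
        # Naive placement: the gated layer runs before the whole block rather than
        # between attn1 and attn2 as in the paper.
        if grounding_tokens is not None:
            hidden_states = block.fuser(hidden_states, grounding_tokens)
        return original_forward(hidden_states, *args, **kwargs)

    block.forward = forward  # instance attribute shadows the class method

The parts I'm unsure about are whether monkey-patching forward like this is safe with diffusers, and how to pass the grounding tokens through UNet2DConditionModel.forward down to each block.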