Adding cross-attention to custom models

Hi, I was thinking of adding cross-attention between a visual transformer and a BERT model. Is there a way to do this using the HF (Hugging Face) library?

My idea: if I had access to the place in the HF BERT model API where it takes in the queries, keys, and values, I could subclass the BERT submodule and add cross-attention alongside the existing self-attention. I'm picturing something very much like the code in The Annotated Transformer, specifically the DecoderLayer class, which has a self_attn module as well as an additional src_attn module. The src_attn part is what I would need to add.
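To make the DecoderLayer idea concrete, here is a minimal plain-Python sketch (not HF code, and not The Annotated Transformer's code) of the two attention sublayers with residual connections: self-attention where queries, keys and values all come from the text tokens, then a src_attn-style cross-attention where the queries come from the text and the keys and values come from the image patch embeddings. It assumes text and image vectors have the same width and deliberately omits the learned query/key/value projections, multiple heads, the causal mask, layer norm and the feed-forward sublayer. The function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V, one query row at a time."""
    d = len(queries[0])
    out = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        # Weighted average of the value rows, column by column.
        out.append([sum(w * v[j] for w, v in zip(weights, values))
                    for j in range(len(values[0]))])
    return out


def add(a: Matrix, b: Matrix) -> Matrix:
    """Element-wise sum, used for the residual connections."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def decoder_layer_sketch(text: Matrix, image: Matrix) -> Matrix:
    # self_attn: queries, keys and values all come from the text tokens.
    x = add(text, attention(text, text, text))
    # src_attn (cross-attention): queries from the text, keys/values from the image patches.
    x = add(x, attention(x, image, image))
    return x
```

For example, calling decoder_layer_sketch with 3 text vectors and 5 image vectors of width 4 returns 3 vectors of width 4: the output always has one row per text token, however many patches the image contributes.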

I'm also aware that I would need to copy the pretrained weights for everything except the src_attn module; I just need some mechanism to do so. Fingers crossed there is somewhere in the HF API that does exactly that.
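One mechanism that does exactly this is PyTorch's load_state_dict(..., strict=False): it copies every tensor whose name matches and reports the rest instead of failing. The sketch below assumes torch and transformers are installed and uses tiny, randomly initialized BertConfig sizes (hypothetical, chosen only so it runs offline); base stands in for the pretrained BERT. For a real checkpoint, from_pretrained performs the same partial copy when given a config with add_cross_attention enabled, and logs the cross-attention weights as newly initialized.

```python
import torch
from transformers import BertConfig, BertModel

# Tiny, hypothetical sizes so nothing needs to be downloaded.
small = dict(vocab_size=100, hidden_size=32, num_hidden_layers=2,
             num_attention_heads=2, intermediate_size=64)

# Stand-in for the pretrained BERT (self-attention only).
base = BertModel(BertConfig(**small))
# Same architecture plus a cross-attention block in every layer.
decoder = BertModel(BertConfig(**small, is_decoder=True, add_cross_attention=True))

# strict=False copies every tensor with a matching name and reports the ones it could not fill.
result = decoder.load_state_dict(base.state_dict(), strict=False)
print(len(result.missing_keys), "cross-attention tensors left at their initial values")
print(result.unexpected_keys)

# Sanity check: a shared self-attention weight really was copied over.
assert torch.equal(
    decoder.encoder.layer[0].attention.self.query.weight,
    base.encoder.layer[0].attention.self.query.weight,
)
```

The missing_keys list is effectively "everything but src_attn" in reverse: in HF's BERT those are the parameters under each layer's crossattention sub-module, which is exactly what is left for training.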

I'm happy to do this myself if someone can point me to where in the HF library the queries, keys, and values are used.

Bump. Sorry guys, just wondering whether anyone has any ideas about this.

Partial answer:

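from transformers import AutoConfig, AutoModelForCausalLM

# Loading the config alone is enough; there is no need to instantiate a model first.
gpt_config = AutoConfig.from_pretrained("gpt2")
gpt_config.add_cross_attention = True  # adds a cross-attention block to every layer
# Existing weights come from the checkpoint; the new cross-attention weights are randomly initialized.
new_model = AutoModelForCausalLM.from_pretrained("gpt2", config=gpt_config)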
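A sketch of how such a model is then called, assuming torch and transformers are installed: the image features are passed as encoder_hidden_states. To run offline it builds a tiny random GPT-2 from a GPT2Config (sizes hypothetical) instead of loading the "gpt2" checkpoint, and random tensors stand in for a visual transformer's patch embeddings, whose last dimension must equal n_embd.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny random GPT-2 with cross-attention (hypothetical sizes; the real "gpt2" uses n_embd=768).
config = GPT2Config(vocab_size=100, n_embd=32, n_layer=2, n_head=2,
                    add_cross_attention=True)
model = GPT2LMHeadModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 7))  # batch of 1, 7 text tokens
image_features = torch.randn(1, 5, config.n_embd)        # stand-in for 5 patch embeddings

with torch.no_grad():
    out = model(input_ids=input_ids, encoder_hidden_states=image_features)

print(out.logits.shape)
```

Here out.logits holds one row of vocabulary scores per text token; the 5 image vectors only enter the computation through the cross-attention blocks.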

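Models like BERT need one additional step: is_decoder must also be set to True, because BERT raises a ValueError if add_cross_attention is enabled on a model that is not configured as a decoder: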

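from transformers import AutoConfig, AutoModel

bert_config = AutoConfig.from_pretrained("bert-base-cased")
bert_config.add_cross_attention = True
bert_config.is_decoder = True  # required with add_cross_attention; also makes self-attention causal
# Existing weights come from the checkpoint; the cross-attention weights are newly initialized.
model2 = AutoModel.from_pretrained("bert-base-cased", config=bert_config)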
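The BERT variant is called the same way. In modeling_bert.py, which is where the queries, keys and values asked about above are computed, each layer's crossattention module applies its query projection to BERT's own hidden states and its key and value projections to encoder_hidden_states. As before, this is a sketch assuming torch and transformers are installed, with a tiny random config (hypothetical sizes) in place of bert-base-cased and random tensors in place of visual-transformer patch embeddings. Their width must equal hidden_size: 768 for both bert-base-cased and a ViT-base model, otherwise a projection layer would be needed first.

```python
import torch
from transformers import BertConfig, BertModel

# Tiny random BERT decoder with cross-attention (hypothetical sizes; bert-base-cased uses 768).
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64,
                    is_decoder=True, add_cross_attention=True)
model = BertModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 7))    # batch of 1, 7 text tokens
image_features = torch.randn(1, 5, config.hidden_size)     # stand-in for 5 patch embeddings

with torch.no_grad():
    out = model(input_ids=input_ids, encoder_hidden_states=image_features)

# One contextualized vector per text token, informed by the image via cross-attention.
print(out.last_hidden_state.shape)
```

Feeding different image features for the same text changes last_hidden_state, which is a quick way to confirm the cross-attention path is actually wired in.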