Hi, I was thinking of adding cross attention between a visual transformer and a BERT model. I was wondering if there is a way I could do this using the HF library.
What I was thinking was: if somewhere in the HF BERT model API I had access to where it takes in the queries, keys, and values, I could subclass the BERT submodule and add cross attention instead of just having self attention. I'm visualizing something very much like the code snippet from The Annotated Transformer, specifically the DecoderLayer class, where there is a self_attn module as well as an additional src_attn module, which is the part I would need to add in (sketched below).
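For concreteness, here is roughly that snippet, lightly condensed from The Annotated Transformer (the `clones` and `SublayerConnection` helpers are from that post, not from HF, and the actual attention modules are passed in as arguments):

```python
import copy
import torch.nn as nn

def clones(module, N):
    "Produce N identical layers (helper from The Annotated Transformer)."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class SublayerConnection(nn.Module):
    "Residual connection around a sublayer, with pre-norm and dropout."
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

class DecoderLayer(nn.Module):
    "Decoder block: self-attention, source (cross) attention, feed-forward."
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn    # attends over the text states themselves
        self.src_attn = src_attn      # attends over the encoder memory -- the part I want to add
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)
```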
I am also aware that I would need to copy the weights for everything but the src_attn module; I just need some mechanism to do so. Fingers crossed there is somewhere in the HF API that lets me do exactly that.
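To make that concrete, what I'm hoping already exists is something along these lines. From skimming the docs it looks like BertConfig has `is_decoder` and `add_cross_attention` flags, so maybe the weight copying is already handled by `from_pretrained` (this is just my guess at the intended usage, untested):

```python
import torch
from transformers import BertModel, BertTokenizer

# Load pretrained BERT, but ask for cross-attention sublayers. If I understand
# correctly, the pretrained self-attention / FFN weights are reused and only the
# new cross-attention weights are randomly initialized (with a warning listing them).
model = BertModel.from_pretrained(
    "bert-base-uncased",
    is_decoder=True,           # cross-attention seems to require decoder mode
    add_cross_attention=True,  # adds a cross-attention block to every BertLayer
)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("a caption about an image", return_tensors="pt")

# Stand-in for the visual transformer's output: (batch, num_patches, hidden_size).
visual_states = torch.randn(1, 197, model.config.hidden_size)

outputs = model(
    **inputs,
    encoder_hidden_states=visual_states,  # keys/values for the cross-attention
)
```

If that is the supported route, my one concern is that `is_decoder=True` also makes the self-attention causal, which isn't obviously what I want for a bidirectional text encoder.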
Happy to do this myself if someone can point me to where in the HF library I should be looking to see where the queries, keys, and values arguments are used.
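From a quick look at the source, the place that seems relevant is BertSelfAttention in `src/transformers/models/bert/modeling_bert.py`, which owns the query/key/value projections and (if I'm reading it right) already accepts `encoder_hidden_states` so the keys and values can come from another encoder. In case it clarifies what I'm after, if I end up doing it by hand I was picturing something like the wrapper below: keep each pretrained BertLayer as-is and only add a new cross-attention module on top (my own hypothetical sketch, using a plain nn.MultiheadAttention rather than anything from HF):

```python
import torch.nn as nn
from transformers import BertModel

class BertLayerWithCrossAttention(nn.Module):
    """Wrap a pretrained BertLayer and bolt a cross-attention step onto it.
    The wrapped layer keeps its pretrained weights; only src_attn is new."""
    def __init__(self, bert_layer, hidden_size, num_heads):
        super().__init__()
        self.bert_layer = bert_layer  # pretrained self-attention + feed-forward
        self.src_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states, visual_states):
        # 1) the ordinary pretrained BERT layer (self-attention + FFN)
        hidden_states = self.bert_layer(hidden_states)[0]
        # 2) new cross-attention: queries from text, keys/values from the visual encoder
        attended, _ = self.src_attn(hidden_states, visual_states, visual_states)
        return self.norm(hidden_states + attended)

# Reuse the pretrained weights for everything except the new src_attn modules.
bert = BertModel.from_pretrained("bert-base-uncased")
layers = [
    BertLayerWithCrossAttention(layer, bert.config.hidden_size, bert.config.num_attention_heads)
    for layer in bert.encoder.layer
]
```

Mostly I just want to know whether I should be subclassing BertSelfAttention itself, or whether the config-flag route above already covers this.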