Can I use a custom attention layer while still leveraging a pre-trained BERT model?

In the paper “Using Prior Knowledge to Guide BERT’s Attention in Semantic Textual Matching Tasks”, the authors multiply a similarity matrix with the attention scores inside the attention layer. I want to apply a similar customization, but I also want to take advantage of the pre-trained BERT weights, since I don’t have enough resources to train from scratch.

Is this possible?

Here’s what I’m planning to do:


config = BertConfig.from_pretrained(pretrained_model_name)
self.pretrained_model = BertModel.from_pretrained(pretrained_model_name, config=config)
# A freshly constructed encoder has randomly initialized weights, so assigning it
# directly would throw away all of BERT's pretrained encoder layers. Build the
# custom encoder, copy the pretrained encoder weights into it, then swap it in.
# NOTE(review): assumes MyCustomBertEncoder keeps BertEncoder's parameter names
# (e.g. "layer.0.attention.self.query.weight"); the strict load_state_dict below
# raises if any name differs, so a mismatch is reported instead of ignored.
custom_encoder = MyCustomBertEncoder(config)
custom_encoder.load_state_dict(self.pretrained_model.encoder.state_dict())
# Keep the train/eval mode of the encoder being replaced (a new module starts in train mode).
custom_encoder.train(self.pretrained_model.encoder.training)
self.pretrained_model.encoder = custom_encoder

Then, inside MyCustomBertEncoder, I’ll define a custom transformer layer that uses my custom self-attention logic.
1 Like

Seems okay?

import torch
import torch.nn as nn
import math

from transformers import BertModel, BertConfig, BertPreTrainedModel
from transformers.models.bert.modeling_bert import BertSelfAttention, BertAttention, BertLayer, BertEncoder

# 1) Define a CustomSelfAttention subclass that injects your similarity weights
class CustomSelfAttention(BertSelfAttention):
    """BERT self-attention whose raw scores are scaled element-wise by a fixed prior.

    The prior (e.g. a word-similarity matrix) is multiplied into the scaled
    dot-product scores *before* the additive attention mask and the softmax.
    The Q/K/V projections and dropout are inherited from ``BertSelfAttention``,
    so pretrained BERT weights can be loaded into this module unchanged.

    Args:
        config: BERT config used to build the inherited projections.
        similarity_matrix: Prior of shape ``[num_heads, seq_len, seq_len]`` or
            ``[1, num_heads, seq_len, seq_len]``. If it was built for a longer
            sequence than the current input, its top-left ``seq x seq`` corner
            is used. Registered as a buffer so it moves with the module's device
            and is stored in its state dict.
    """

    def __init__(self, config, similarity_matrix: torch.Tensor):
        super().__init__(config)
        # Buffer (not a Parameter): follows .to(device) but is not trained.
        self.register_buffer("prior_sim", similarity_matrix)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``[batch, seq, all_head_size]`` to ``[batch, heads, seq, head_dim]``."""
        # Done locally instead of via the parent's ``transpose_for_scores``, an
        # internal helper that is not guaranteed to exist in every transformers version.
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(new_shape).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        **kwargs,
    ):
        """Compute prior-weighted self-attention.

        ``encoder_hidden_states``, ``encoder_attention_mask``, ``past_key_value``
        and any extra keyword arguments are accepted for call-compatibility with
        ``BertAttention`` but ignored: this module does encoder self-attention only.

        Returns:
            ``(context,)``, or ``(context, attn_probs)`` when ``output_attentions``
            is true; ``context`` has shape ``[batch, seq, all_head_size]``.
        """
        # Standard projections, reshaped to [batch, heads, seq, head_dim]
        query_layer = self._split_heads(self.query(hidden_states))
        key_layer = self._split_heads(self.key(hidden_states))
        value_layer = self._split_heads(self.value(hidden_states))

        # Raw attention scores: [batch, heads, seq_q, seq_k]
        scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)

        # Inject prior knowledge: multiply element-wise.
        prior = self.prior_sim
        if prior.dim() == 3:
            prior = prior.unsqueeze(0)  # [1, heads, seq, seq]
        # Crop a prior built for a longer max length to the current lengths;
        # size-1 (broadcast) dimensions are left untouched by the slice.
        prior = prior[..., : scores.size(-2), : scores.size(-1)]
        # Match the scores' dtype so fp16/bf16 scores are not promoted to fp32
        # (which would break the matmul with value_layer below).
        scores = scores * prior.to(dtype=scores.dtype)

        # Standard masking, softmax, dropout
        if attention_mask is not None:
            scores = scores + attention_mask
        attn_probs = torch.softmax(scores, dim=-1)
        attn_probs = self.dropout(attn_probs)

        # Optionally apply head_mask
        if head_mask is not None:
            attn_probs = attn_probs * head_mask

        # Compute context and merge heads back to [batch, seq, all_head_size]
        context = torch.matmul(attn_probs, value_layer)
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(context.size()[:-2] + (self.all_head_size,))

        return (context, attn_probs) if output_attentions else (context,)

# 2) Utility to replace all BertSelfAttention modules in the model
def inject_custom_attention(model: BertModel, sim_matrix: torch.Tensor):
    """Replace every layer's self-attention with ``CustomSelfAttention``, in place.

    The pretrained query/key/value weights of each original self-attention
    module are copied into its replacement. The replacement is also put on the
    same device/dtype and in the same train/eval mode, so the model keeps its
    pretrained behaviour apart from the injected prior.

    Args:
        model: A ``BertModel`` whose ``encoder.layer`` modules are modified in place.
        sim_matrix: Prior handed to every ``CustomSelfAttention`` (shared by all layers).

    Returns:
        The same ``model`` object, for chaining.

    Raises:
        RuntimeError: If the pretrained weights cannot be copied one-to-one,
            i.e. any key other than ``prior_sim`` is missing or unexpected.
    """
    for layer in model.encoder.layer:
        # layer.attention is a BertAttention: has .self (BertSelfAttention) and .output
        old_attn = layer.attention.self
        new_attn = CustomSelfAttention(model.config, sim_matrix)

        # A freshly built module has randomly initialized Q/K/V; copy the
        # pretrained ones. ``prior_sim`` is the only key the old module lacks.
        missing, unexpected = new_attn.load_state_dict(old_attn.state_dict(), strict=False)
        missing = [key for key in missing if key != "prior_sim"]
        if missing or unexpected:
            raise RuntimeError(
                f"Could not copy pretrained attention weights: "
                f"missing={missing}, unexpected={list(unexpected)}"
            )

        # Match the replaced module's device/dtype and train/eval mode; a new
        # module starts in train mode, which would silently enable dropout on a
        # model that from_pretrained() returned in eval mode.
        reference = old_attn.query.weight
        new_attn.to(device=reference.device, dtype=reference.dtype)
        new_attn.train(old_attn.training)

        layer.attention.self = new_attn
    return model

# 3) Example usage
def main():
    """Demo: load bert-base-uncased, inject an all-ones prior and run one forward pass.

    Parameters whose names do not contain ``attention.self`` are frozen, so a
    later fine-tuning loop would only update the self-attention projections.
    """
    checkpoint = "bert-base-uncased"
    max_len = 128  # the prior below is built for exactly this sequence length

    cfg = BertConfig.from_pretrained(checkpoint)
    # Placeholder prior of shape [heads, max_len, max_len]; all ones means "no effect".
    placeholder_prior = torch.ones((cfg.num_attention_heads, max_len, max_len))
    model = inject_custom_attention(
        BertModel.from_pretrained(checkpoint, config=cfg),
        placeholder_prior,
    )

    # Freeze everything outside the self-attention modules.
    frozen = (p for pname, p in model.named_parameters() if "attention.self" not in pname)
    for p in frozen:
        p.requires_grad_(False)

    # Dummy single-example batch: token ids 0..max_len-1, nothing masked.
    token_ids = torch.arange(max_len)[None, :]  # [1, max_len]
    hidden = model(
        input_ids=token_ids,
        attention_mask=torch.ones_like(token_ids),
    ).last_hidden_state  # [1, max_len, hidden_dim]

    # Now fine-tune with your task-specific head...
    print("Output shape:", hidden.shape)


if __name__ == "__main__":
    main()
1 Like

(post deleted by author)

Thank you, this looks promising. However, I’m not sure how to pass the similarity matrix to the forward method of the attention class during fine-tuning: the similarity matrix changes for each sample, and I shouldn’t have to reinitialize the attention module for every batch.

1 Like

Hmm… Like this?

import math
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, AutoTokenizer
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.models.bert.modeling_bert import BertSelfAttention

# 1) Custom attention that accepts sim_matrix
class SimSelfAttention(BertSelfAttention):
    """``BertSelfAttention`` variant whose scores can be re-weighted per call.

    All weights (query/key/value projections, dropout) are inherited unchanged,
    so pretrained BERT checkpoints load into it directly. The only addition is
    the ``sim_matrix`` keyword of ``forward``.
    """

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        sim_matrix=None,
        **kwargs
    ):
        """Scaled dot-product self-attention, optionally scaled by ``sim_matrix``.

        ``sim_matrix`` (if given) is multiplied element-wise into the scaled
        scores before ``attention_mask`` is added; it must broadcast against
        ``[batch, heads, seq, seq]``. Extra keyword arguments are ignored.

        Returns:
            ``(context,)``, or ``(context, probs)`` when ``output_attentions``
            is true; ``context`` has shape ``[batch, seq, heads * head_dim]``.
        """
        heads, head_dim = self.num_attention_heads, self.attention_head_size

        def split_heads(projected):
            # [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
            shape = projected.size()[:-1] + (heads, head_dim)
            return projected.view(shape).permute(0, 2, 1, 3)

        # Pretrained Q/K/V projections, split into heads.
        query, key, value = (
            split_heads(proj(hidden_states)) for proj in (self.query, self.key, self.value)
        )

        logits = query @ key.transpose(-1, -2) / math.sqrt(head_dim)
        if sim_matrix is not None:
            logits = logits * sim_matrix  # dynamic per-batch prior
        if attention_mask is not None:
            logits = logits + attention_mask

        weights = self.dropout(torch.softmax(logits, dim=-1))
        if head_mask is not None:
            weights = weights * head_mask

        # Merge heads back: [batch, heads, seq, dim] -> [batch, seq, heads * dim]
        merged = (weights @ value).permute(0, 2, 1, 3).contiguous()
        merged = merged.view(merged.size(0), merged.size(1), -1)

        if output_attentions:
            return merged, weights
        return (merged,)

# 2) Subclass BertModel—inject SimSelfAttention and accept sim_matrix
class CustomBertWithSim(BertModel):
    """``BertModel`` whose self-attention layers accept a per-batch similarity prior.

    Every layer's ``attention.self`` is replaced by ``SimSelfAttention``. Because
    the replacement keeps ``BertSelfAttention``'s parameter names, loading a
    checkpoint with ``from_pretrained`` fills it with the pretrained Q/K/V
    weights. ``forward`` re-implements the encoder loop so that ``sim_matrix``
    can be forwarded to every self-attention call.
    """

    def __init__(self, config: BertConfig):
        super().__init__(config)
        # replace each layer's self-attention
        for layer in self.encoder.layer:
            layer.attention.self = SimSelfAttention(config)

    @staticmethod
    def _additive_attention_mask(attention_mask, dtype):
        """Turn a 0/1 padding mask into an additive mask broadcastable to the scores.

        ``[B, S]`` becomes ``[B, 1, 1, S]`` and ``[B, S, S]`` becomes ``[B, 1, S, S]``.
        Kept positions map to 0 and masked positions to -1e4. The result is in
        ``dtype`` so it matches the attention scores under fp16/bf16.

        Raises:
            ValueError: If the mask is neither 2-D nor 3-D.
        """
        if attention_mask is None:
            return None
        if attention_mask.dim() == 2:
            extended = attention_mask[:, None, None, :]
        elif attention_mask.dim() == 3:
            extended = attention_mask[:, None, :, :]
        else:
            raise ValueError(
                f"attention_mask must be 2-D [B, S] or 3-D [B, S, S], got shape "
                f"{tuple(attention_mask.shape)}"
            )
        extended = extended.to(dtype=dtype)
        return (1.0 - extended) * -1e4

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        sim_matrix=None,     # ← accept here
    ):
        """Run BERT with ``sim_matrix`` forwarded to every self-attention layer.

        Arguments mirror ``BertModel.forward``. ``attention_mask`` is a 0/1
        padding mask of shape ``[B, S]`` or ``[B, S, S]``. ``sim_matrix``
        (optional) must broadcast against the ``[B, heads, S, S]`` attention
        scores and is shared by all layers.

        Returns:
            ``BaseModelOutputWithPoolingAndCrossAttentions``, or, when
            ``return_dict`` is false, the tuple
            ``(last_hidden_state, pooler_output[, hidden_states][, attentions])``.
            ``hidden_states`` holds the embedding output followed by the output
            of every layer.
        """
        # mirror BertModel.forward defaults
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 1) Embeddings
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        # 2) Prepare masks (additive mask in the model's activation dtype)
        extended_mask = self._additive_attention_mask(attention_mask, embedding_output.dtype)
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # 3) Encoder loop (no extra layernorm!)
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        hidden_states = embedding_output

        for i, layer_module in enumerate(self.encoder.layer):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # 3a) Self-attention with sim_matrix (bypasses BertAttention.forward
            # so the extra keyword can be passed through)
            attn_outputs = layer_module.attention.self(
                hidden_states,
                attention_mask=extended_mask,
                head_mask=head_mask[i] if head_mask is not None else None,
                output_attentions=output_attentions,
                sim_matrix=sim_matrix,   # ← forwarded
            )
            # Projection + dropout + residual LayerNorm, as in BertAttention
            attn_output = layer_module.attention.output(attn_outputs[0], hidden_states)

            # 3b) Feed-forward
            intermediate_output = layer_module.intermediate(attn_output)
            hidden_states = layer_module.output(intermediate_output, attn_output)

            if output_attentions:
                all_attentions += (attn_outputs[1],)

        # Include the last layer's output too: embeddings + one entry per layer,
        # matching BertModel's hidden_states convention.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # 4) Pooler (no extra encoder.layernorm)
        pooled_output = self.pooler(hidden_states) if self.pooler is not None else None

        # 5) Return in requested format
        if not return_dict:
            outputs = (hidden_states, pooled_output)
            if output_hidden_states:
                outputs += (all_hidden_states,)
            if output_attentions:
                outputs += (all_attentions,)
            return outputs

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=hidden_states,
            pooler_output=pooled_output,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

if __name__ == "__main__":
    model_name = "bert-base-uncased"
    config = BertConfig.from_pretrained(model_name)
    model = CustomBertWithSim.from_pretrained(model_name, config=config)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    texts = ["Hello world!", "How are you?"]
    enc = tokenizer(texts, padding="max_length", truncation=True, max_length=16, return_tensors="pt")

    # Pass the tokenizer's 2-D 0/1 attention mask as-is: CustomBertWithSim.forward
    # builds the additive mask itself, so no pre-extended 4-D mask is needed here.

    # dummy sim_matrix [B, heads, seq, seq]; replace with the real per-sample prior
    B, S = enc.input_ids.shape
    H = model.config.num_attention_heads
    sim = torch.rand((B, H, S, S))

    # Inference-only demo: no autograd graph needed.
    with torch.no_grad():
        outputs = model(
            input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
            # Segment ids matter for sentence-pair (text-matching) inputs.
            token_type_ids=enc.get("token_type_ids"),
            sim_matrix=sim,
        )
    print(outputs.last_hidden_state.shape)  # → (2, 16, 768)