Hmm… Like this?
import math
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, AutoTokenizer
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.models.bert.modeling_bert import BertSelfAttention

# 1) Custom attention that accepts sim_matrix
class SimSelfAttention(BertSelfAttention):
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        sim_matrix=None,  # ← dynamic similarity
        **kwargs
    ):
        # Q/K/V projections (pretrained weights)
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)
        # reshape to [B, heads, seq, dim]
        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)
        # scaled dot-product
        scores = torch.matmul(q, k.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        # inject sim_matrix if provided
        if sim_matrix is not None:
            scores = scores * sim_matrix
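            # (elementwise scaling of the raw attention logits; a sim value of
            #  1.0 leaves that token pair's logit unchanged)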
        # apply mask, softmax, dropout
        if attention_mask is not None:
            scores = scores + attention_mask
        probs = nn.functional.softmax(scores, dim=-1)
        probs = self.dropout(probs)
        if head_mask is not None:
            probs = probs * head_mask
        # compute context
        context = torch.matmul(probs, v)
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(context.size(0), context.size(1), -1)
        return (context, probs) if output_attentions else (context,)

# 2) Subclass BertModel: inject SimSelfAttention and accept sim_matrix
class CustomBertWithSim(BertModel):
    def __init__(self, config: BertConfig):
        super().__init__(config)
        # replace each layer's self-attention
        for layer in self.encoder.layer:
            layer.attention.self = SimSelfAttention(config)
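        # note: SimSelfAttention reuses the parent's query/key/value parameter
        # names, so from_pretrained() still loads the pretrained weights into
        # the swapped-in modules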

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        sim_matrix=None,  # ← accept here
    ):
        # mirror BertModel.forward defaults
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 1) Embeddings
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        # 2) Prepare masks
        if attention_mask is not None:
            # [B, seq] → [B, 1, 1, seq] additive mask
            extended_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            extended_mask = (1.0 - extended_mask) * -1e4
        else:
            extended_mask = None
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        # 3) Encoder loop (no extra layernorm!)
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        hidden_states = embedding_output
        for i, layer_module in enumerate(self.encoder.layer):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # 3a) Self-attention with sim_matrix
            attn_outputs = layer_module.attention.self(
                hidden_states,
                attention_mask=extended_mask,
                head_mask=head_mask[i] if head_mask is not None else None,
                output_attentions=output_attentions,
                sim_matrix=sim_matrix,  # ← forwarded
            )
            attn_output = layer_module.attention.output(attn_outputs[0], hidden_states)
            # 3b) Feed-forward
            intermediate_output = layer_module.intermediate(attn_output)
            hidden_states = layer_module.output(intermediate_output, attn_output)
            if output_attentions:
                all_attentions += (attn_outputs[1],)
        # don't forget the final layer's output in the collected hidden states
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        # 4) Pooler (no extra encoder.layernorm)
        pooled_output = self.pooler(hidden_states) if self.pooler is not None else None

        # 5) Return in requested format
        if not return_dict:
            outputs = (hidden_states, pooled_output)
            if output_hidden_states:
                outputs += (all_hidden_states,)
            if output_attentions:
                outputs += (all_attentions,)
            return outputs
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=hidden_states,
            pooler_output=pooled_output,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

if __name__ == "__main__":
    model_name = "bert-base-uncased"
    config = BertConfig.from_pretrained(model_name)
    model = CustomBertWithSim.from_pretrained(model_name, config=config)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    texts = ["Hello world!", "How are you?"]
    enc = tokenizer(texts, padding="max_length", truncation=True, max_length=16, return_tensors="pt")
    # no need to build a 4D mask here: the model converts attention_mask
    # into the additive [B, 1, 1, seq] mask internally
    # dummy sim_matrix [B, heads, seq, seq]
    B, S = enc.input_ids.shape
    H = model.config.num_attention_heads
    sim = torch.rand((B, H, S, S))

    outputs = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask, sim_matrix=sim)
    print(outputs.last_hidden_state.shape)  # → (2, 16, 768)
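
And if a random sim_matrix feels too arbitrary: one way you might build a real one (just a sketch; build_sim_matrix is a made-up helper, not part of the model above) is cosine similarity between the token embeddings, broadcast over heads. Any [B, heads, seq, seq] tensor works, since it only rescales the logits elementwise:

def build_sim_matrix(bert, input_ids, num_heads):
    # cosine similarity between the (pretrained) token embeddings
    with torch.no_grad():
        emb = bert.embeddings.word_embeddings(input_ids)   # [B, S, hidden]
        emb = nn.functional.normalize(emb, dim=-1)
        sim = torch.matmul(emb, emb.transpose(-1, -2))     # [B, S, S], values in [-1, 1]
    sim = sim.clamp_min(0.0)                               # keep the scaling non-negative
    return sim.unsqueeze(1).expand(-1, num_heads, -1, -1)  # [B, heads, S, S]

# reusing model, enc and H from the demo above
sim = build_sim_matrix(model, enc.input_ids, H)
outputs = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask, sim_matrix=sim)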