Code from the HF tutorial on customizing transformer components is not working as intended

I tried to follow the instructions from this page

but could not even reproduce the example given there. At this point, I simply want to confirm that a model's attention can be replaced with a custom module. So I copied the code and deliberately introduced an error in the attention initialization, expecting the model to crash if the custom attention were actually used. It did not crash, which suggests the attention is not being replaced by the custom module.

Am I missing something, or is the tutorial outdated?

import torch
import torch.nn as nn
from transformers.models.sam.modeling_sam import SamVisionAttention

class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
    def __init__(self, config, window_size):
        super().__init__(config, window_size)
        # remove combined qkv
        del self.qkv
        # separate q, k, v projections
        #self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        # introduce error here
        self.q = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=config.qkv_bias)
        self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)

    def split_q_k_v_load_hook(self, state_dict, prefix, *args):
        keys_to_delete = []
        for key in list(state_dict.keys()):
            if "qkv." in key:
                # split q, k, v from the combined projection
                q, k, v = state_dict[key].chunk(3, dim=0)
                # replace with individual q, k, v projections
                state_dict[key.replace("qkv.", "q.")] = q
                state_dict[key.replace("qkv.", "k.")] = k
                state_dict[key.replace("qkv.", "v.")] = v
                # mark the old qkv key for deletion
                keys_to_delete.append(key)

        # remove old qkv keys
        for key in keys_to_delete:
            del state_dict[key]

    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        qkv_shapes = (batch_size * self.num_attention_heads, height * width, -1)
        query = self.q(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)
        key = self.k(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)
        value = self.v(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)
        return outputs

from transformers import SamModel
from transformers.models.sam import modeling_sam
from transformers.utils.import_utils import clear_import_cache

# replace the attention class in the modeling_sam module
modeling_sam.SamVisionAttention = SamVisionAttentionSplit

# clear cache to reload modified code
clear_import_cache()

# load the pretrained SAM model
model = SamModel.from_pretrained("facebook/sam-vit-base")

# Example input tensors
pixel_values = torch.randn(1, 3, 1024, 1024)  # Batch size 1, 3 channels, 1024x1024 image
original_sizes = torch.tensor([[1024, 1024]])  # Original size of the image

# Test the model
outputs = model(pixel_values=pixel_values, original_sizes=original_sizes)
print(outputs.iou_scores)
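
For what it's worth, a direct check of which attention class ended up in the loaded model would make the "did the swap happen?" question explicit (this assumes the vision encoder exposes its attention modules as .attn, as in modeling_sam):

print(type(model.vision_encoder.layers[0].attn))
# if the patch took effect, this should report SamVisionAttentionSplit rather than the stock SamVisionAttention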

It worked normally in my environment with 4.49.0.dev0. I think the tutorial code is simply outdated for the newer releases.

tensor([[[0.8896, 0.7523, 0.2491]]], grad_fn=)

I’m not very familiar with it, but Transformers seems to have been significantly revamped in 4.49.0 and 4.50.0, so I think it should work if you specify the following version for now.

pip install "transformers<4.49.0"

or

pip install "transformers<=4.49.0"
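
If it helps, a quick way to confirm which release is actually active in the environment afterwards:

import transformers
print(transformers.__version__)  # prints the installed transformers version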

Thanks a lot, I made the above example work by switching to pip install transformers==4.47.0, though I had to drop clear_import_cache() (which was added in a more recent release, as far as I understand).
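
A minimal guard in case it is useful to others; assuming the only missing piece on 4.47.0 is that import, the call can simply become a no-op:

try:
    from transformers.utils.import_utils import clear_import_cache
except ImportError:
    # clear_import_cache does not exist in older releases such as 4.47.0,
    # so fall back to a no-op and keep the rest of the tutorial script unchanged
    def clear_import_cache():
        pass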

However, I could not adapt this code to an OPT model. Again, my attention modifications had no effect. Should this type of model component customization work for all models, or only for some? If it is the latter, can one tell by looking at the model files whether it is going to work for a given model? Thank you for your help.
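
For concreteness, this is a rough way to peek at where a model wires up its attention class (just inspecting the modeling module; nothing here is specific to OPT beyond the module path):

import inspect
from transformers.models.opt import modeling_opt

# show how the decoder layer chooses its attention module; if the class is looked up
# through a mapping built at import time, patching modeling_opt.OPTAttention afterwards
# would not affect layers that were already wired to the original class
print(inspect.getsource(modeling_opt.OPTDecoderLayer.__init__))
print([name for name in dir(modeling_opt) if "ATTENTION" in name.upper()])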


I wonder…
When customizing attention, a different procedure may be necessary (e.g. registering a custom attention function through AttentionInterface).
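
A rough sketch of that route, based on the attention-interface docs (it requires a recent release, and the docs example uses a Llama model):

import torch
from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward

def my_sdpa_with_hook(*args, **kwargs):
    # wrap the stock SDPA attention so it is visible when the custom path is used
    print("custom attention called")
    return sdpa_attention_forward(*args, **kwargs)

# register the function under a new name and select it when loading the model
AttentionInterface.register("my_sdpa_with_hook", my_sdpa_with_hook)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="my_sdpa_with_hook")
model(torch.ones(1, 5, dtype=int))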


Thank you, I tried AttentionInterface before; it worked for the model in the example (meta-llama/Llama-3.2-1B), but not for OPT models. I have not figured out how to customize the attention directly. For now, the best solution I found was to subclass the relevant classes to gain access to the attention, like so:

import torch.nn as nn
from transformers.models.opt.modeling_opt import OPTModel, OPTDecoder, OPTDecoderLayer, OPTAttention

class myOPTModel(OPTModel):
  def __init__(self, config):
    super().__init__(config)
    self.decoder = myOPTDecoder(config)
    
class myOPTDecoder(OPTDecoder):
  def __init__(self, config):
    super().__init__(config)
    self.layers = nn.ModuleList([myOPTDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])

class myOPTDecoderLayer(OPTDecoderLayer):
  def __init__(self, config, layer_idx):
    super().__init__(config, layer_idx)
    self.self_attn = myOPTAttention(config, layer_idx)

class myOPTAttention(OPTAttention):
  def __init__(self, config, layer_idx):
    super().__init__(config, layer_idx)

  def forward(self, hidden_states, **kwargs):
    # my modification to forward() here
    # (placeholder so the subclass runs as-is: fall back to the stock attention forward)
    return super().forward(hidden_states, **kwargs)

It seems straightforward, if a bit tedious.
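
A minimal check along those lines, assuming facebook/opt-125m and the layer layout from modeling_opt:

model = myOPTModel.from_pretrained("facebook/opt-125m")
# the decoder layers should now expose the subclassed attention
print(type(model.decoder.layers[0].self_attn))  # expected: myOPTAttention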
