Remove causal mask from Llama decoder

Hi,

I want to train the recently released smaller Llama 3.2 models (link) for an NER task.

However, I would like to remove the causal LM triangular mask during training and inference. How do I go about this?

First of all, I thought the mask was automatically generated based on model.config.is_decoder in get_extended_attention_mask (link). If so, simply setting this flag to False should enable bidirectional attention.

However, model.config.is_decoder is False for AutoModelForCausalLM (as well as for AutoModelForTokenClassification).

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForTokenClassification


model_checkpoint = "meta-llama/Llama-3.2-1B"

model = AutoModelForCausalLM.from_pretrained(model_checkpoint)
assert model.config.is_decoder == False

model = AutoModelForTokenClassification.from_pretrained(model_checkpoint)
assert model.config.is_decoder == False

How can this be the case? The flag should be True for AutoModelForCausalLM.


Hi @abhinavkulkarni
I might propose the following patch, let me know if this works for you.

from transformers import LlamaForCausalLM

class MyModel(LlamaForCausalLM):  # subclass the concrete model class so the override is actually used
    def get_extended_attention_mask(self, *args, **kwargs):  # let's override the method here
        pass

Setting model.config.is_decoder to False is not going to work, since the flag is already False for both AutoModelForCausalLM and AutoModelForTokenClassification. As in the previous answer, you can override get_extended_attention_mask.

But may I ask why you want to change the causal attention to bidirectional? If it's for an NER task, why not use BERT?
Also, you could try using LLM2Vec.
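For reference, the encoder route for NER is a one-liner with AutoModelForTokenClassification; bert-base-cased and num_labels=9 below are just example values:

from transformers import AutoTokenizer, AutoModelForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForTokenClassification.from_pretrained("bert-base-cased", num_labels=9)

inputs = tokenizer("Hugging Face is based in New York City", return_tensors="pt")
logits = model(**inputs).logits  # (1, seq_len, num_labels), computed with bidirectional attention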


You need to override LlamaModel._update_causal_mask, or pass your own 4D attention bias as the attention_mask argument, with shape (1, 1, L, L), where mask[0, 0, i, j] = 0 unless i or j is a padding position, in which case it equals a large negative number.
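For the second option, here's a minimal sketch of building that bias from the tokenizer's 1/0 padding mask (build_full_attention_bias is just an illustrative helper name, and whether a pre-built 4D mask is passed through untouched depends on your transformers version):

import torch

def build_full_attention_bias(padding_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Additive (1, 1, L, L) bias: 0 where both positions are real tokens,
    a large negative number whenever either position is padding."""
    padding_mask = padding_mask.reshape(-1).bool()                # (L,)
    keep = padding_mask.unsqueeze(0) & padding_mask.unsqueeze(1)  # (L, L)
    bias = torch.zeros(keep.shape, dtype=dtype)
    bias.masked_fill_(~keep, torch.finfo(dtype).min)
    return bias[None, None, :, :]                                 # (1, 1, L, L)

# Example usage with a single sequence:
# inputs = tokenizer(text, return_tensors="pt")
# bias = build_full_attention_bias(inputs["attention_mask"], model.dtype)
# outputs = model(inputs["input_ids"], attention_mask=bias)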


I was able to create a full attention mask as follows:

import torch
from transformers import LlamaModel

class LlamaFullAttentionModel(LlamaModel):
    def _update_causal_mask(self, *args, **kwargs):
        # Build the standard causal mask first (large negative values above the diagonal).
        causal_mask = super()._update_causal_mask(*args, **kwargs)
        # torch.tril zeroes everything above the diagonal, lifting the causal
        # constraint while keeping the padding entries at or below it.
        full_attention_mask = torch.tril(causal_mask)
        return full_attention_mask
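In case it helps, here is one way you might plug this backbone under a token-classification head. This is only a sketch: it assumes your transformers version exports LlamaForTokenClassification, and the subclass name below is made up.

from transformers import LlamaForTokenClassification

class LlamaFullAttentionForTokenClassification(LlamaForTokenClassification):
    def __init__(self, config):
        super().__init__(config)
        # Swap in the full-attention backbone; from_pretrained will still load
        # the pretrained Llama weights into it because the attribute name matches.
        self.model = LlamaFullAttentionModel(config)

model = LlamaFullAttentionForTokenClassification.from_pretrained(
    "meta-llama/Llama-3.2-1B", num_labels=9  # num_labels is an example value
)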
