3-dimensional attention_mask in LongformerSelfAttention

Hi everyone, I am learning NLP and working on a project that uses distilroberta as the encoder model. I have read about K-BERT, and I want to implement that model with distilroberta. Briefly, K-BERT works alongside a Knowledge Graph. Each word in an input sentence is queried in that Knowledge Graph, and the retrieved information is injected into the input sentence next to that word. The information from the Knowledge Graph is related only to its owner token. So there is a visible matrix that decides which tokens are related to a given token, and a position index that indicates which token is related to the others. The visible matrix acts as an attention_mask, but while an attention_mask is usually [batch, seq_len], this kind of attention mask is [batch, seq_len, seq_len]. Therefore, when using distilroberta, I need to pass the position index as position_ids and the visible matrix as the attention_mask argument. Thankfully, modeling_roberta accepts an attention_mask with 3 dimensions and creates an extended attention mask as [:, None, :, :], so my visible matrix works. However, when I try to convert distilroberta to Longformer, there is a problem in LongformerSelfAttention:

# values to pad for attention probs
remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]

If attention_mask is [batch, seq_len], then after this transformation remove_from_windowed_attention_mask is [batch, seq_len, 1, 1]. With my visible matrix, however, the result is [batch, seq_len, 1, 1, seq_len], which then causes an error at:

def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int):
        """
        Matrix multiplication of query and key tensors using with a sliding window attention pattern. This
        implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an
        overlap of size window_overlap
        """
        batch_size, seq_len, num_heads, head_dim = query.size()
        assert (
            seq_len % (window_overlap * 2) == 0
        ), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
        assert query.size() == key.size()
 
ValueError: too many values to unpack (expected 4)
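
To make the shape problem concrete, here is a minimal sketch that reproduces it with dummy tensors, without loading any model. The sizes and the helper name `unpack_like_sliding_chunks` are made up for illustration. It assumes that a tensor with the mask's shape is what ends up being unpacked into four values, which is what the traceback suggests.

```python
import torch

# Toy sizes, chosen only for illustration.
batch, seq_len = 2, 8

# Usual padding mask: one value per token.
mask_2d = torch.ones(batch, seq_len)
# K-BERT-style visible matrix: one value per (query token, key token) pair.
mask_3d = torch.ones(batch, seq_len, seq_len)

# RoBERTa-style extension of a 3-D mask: add a broadcastable head axis.
extended_3d = mask_3d[:, None, :, :]            # [batch, 1, seq_len, seq_len]

# Longformer-style indexing from the snippet above.
windowed_2d = (mask_2d != 0)[:, :, None, None]  # [batch, seq_len, 1, 1]
windowed_3d = (mask_3d != 0)[:, :, None, None]  # [batch, seq_len, 1, 1, seq_len]


def unpack_like_sliding_chunks(t):
    # Hypothetical helper that mirrors the 4-way unpack at the top of
    # _sliding_chunks_query_key_matmul.
    batch_size, seq_len_, num_heads, head_dim = t.size()
    return batch_size, seq_len_, num_heads, head_dim


print(tuple(extended_3d.shape))
print(tuple(windowed_2d.shape))
print(tuple(windowed_3d.shape))

print(unpack_like_sliding_chunks(windowed_2d))  # 4 dimensions: unpack works
try:
    unpack_like_sliding_chunks(windowed_3d)     # 5 dimensions: unpack fails
except ValueError as err:
    print("ValueError:", err)
```

The 2-D mask gives a 4-D tensor that unpacks cleanly, while the 3-D visible matrix keeps its trailing seq_len axis and gives a 5-D tensor, so the 4-way unpack raises the ValueError.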

I do not know how to solve this problem. I am worried that if I change the attention mask transformation to suit my visible matrix, it will affect something else unexpectedly. I would really appreciate your help.
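
For context, here is a minimal sketch of the kind of visible matrix and position index described above. It is a simplified illustration, not the K-BERT reference code: the function name `build_visible_matrix`, the example sentence, and the injected tokens are all made up, and it assumes at most one injected branch per owner token.

```python
def build_visible_matrix(sentence, branches):
    """Build tokens, position ids and a 0/1 visible matrix.

    sentence: list of tokens of the original sentence.
    branches: dict mapping a sentence index to the list of knowledge
              tokens injected right after that token (hypothetical input).
    """
    tokens, position_ids, groups = [], [], []
    for idx, tok in enumerate(sentence):
        tokens.append(tok)
        position_ids.append(idx)
        groups.append((idx, False))           # (owner index, is_injected)
        for offset, ktok in enumerate(branches.get(idx, []), start=1):
            tokens.append(ktok)
            # Injected tokens continue counting from their owner token.
            position_ids.append(idx + offset)
            groups.append((idx, True))

    n = len(tokens)
    visible = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            owner_i, injected_i = groups[i]
            owner_j, injected_j = groups[j]
            if not injected_i and not injected_j:
                visible[i][j] = 1             # sentence tokens see each other
            elif owner_i == owner_j:
                visible[i][j] = 1             # owner token and its own branch
    return tokens, position_ids, visible


sentence = ["Tim", "Cook", "visits", "Beijing"]
branches = {1: ["CEO", "Apple"], 3: ["capital", "China"]}
tokens, position_ids, visible = build_visible_matrix(sentence, branches)

print(tokens)
print(position_ids)
for row in visible:
    print(row)
```

For one sentence this gives a [seq_len, seq_len] matrix. Stacking one such matrix per example is what produces the [batch, seq_len, seq_len] attention_mask.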