Errors in 3D attention mask for T5ForConditionalGeneration

Hi,
I’m trying to fine-tune a T5 model (T5ForConditionalGeneration) by customizing the 2D attention mask into a 3D one. I followed most (maybe all?) of the tutorials, notebooks, and code snippets from the Transformers library to understand what to do (thank you for the sources), but so far I’m getting errors after editing modeling_t5.py to add lines for the 3D attention mask, as shown in the code below.

Here is the error message that I got:

Traceback (most recent call last):
  File "scripts/2-5_Train_T5model_custom.py", line 25, in <module>
    trainer.train()
  File "/home/customized_t5/models/trainer.py", line 79, in train
    trainer.train(resume_from_checkpoint=self.training_args.resume_from_checkpoint)
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/transformers/trainer.py", line 1498, in train
    return inner_training_loop(
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/transformers/trainer.py", line 1740, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/transformers/trainer.py", line 2470, in training_step
    loss = self.compute_loss(model, inputs)
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/transformers/trainer.py", line 2502, in compute_loss
    outputs = model(**inputs)
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/transformers/adapters/context.py", line 108, in wrapper_func
    results = f(self, *args, **kwargs)
  File "/home/customized_t5/models/t5/modeling_t5.py", line 1701, in forward
    decoder_outputs = self.decoder(
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/customized_t5/models/t5/modeling_t5.py", line 1075, in forward
    layer_outputs = layer_module(
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/customized_t5/models/t5/modeling_t5.py", line 724, in forward
    cross_attention_outputs = self.layer[1](
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/customized_t5/models/t5/modeling_t5.py", line 633, in forward
    attention_output = self.EncDecAttention(
  File "/opt/conda/envs/customized_t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/customized_t5/models/t5/modeling_t5.py", line 551, in forward
    position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
RuntimeError: The size of tensor a (55) must match the size of tensor b (475) at non-singleton dimension 2

What I edited in T5Stack in modeling_t5.py:

def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        encoder_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        type_ids=None,
        row_ids=None,
        col_ids=None,
    ):
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
...
#Edited----------------------------------------------------------------
if not self.is_decoder and len(attention_mask.shape) == 2:
    attention_mask = torch.bmm(torch.unsqueeze(attention_mask.float(), 2),
            torch.unsqueeze(attention_mask.float(), 1)) > 0.5
#------------------------------------------------------------------------

extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
...

What I edited in T5ForConditionalGeneration in modeling_t5.py:

def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
...

# Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # EDITED----------------------------------------------------
            attention_mask = torch.bmm(torch.unsqueeze(attention_mask.float(), 2),
                    torch.unsqueeze(attention_mask.float(), 1)) > 0.5
            # ---------------------------------------------------------------
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

What am I missing here? Should I add or edit more things in modeling_t5.py? (Note: the traceback shows the mismatch happens in the decoder's cross-attention, where the encoder_attention_mask — not the encoder self-attention mask — is applied, so the 3D mask I built for the encoder ends up with the wrong key length there.)

Thank you, and I hope you have a Happy New Year!