Problem with Adding LayerNorm after BART's Encoder for Summarization

Moving the discussion from Issues to here.

I am trying to add additional layers/encoders after the BartEncoder, consisting of self-attention and LayerNorm layers. After debugging, I found that whenever the LayerNorm is called, the model cannot produce reasonable ROUGE scores at test time. Here is the minimal reproduction code.

  1. I used examples/pytorch/summarization/run_summarization.py. The changes I made (which I think are harmless) are commenting out the version requirement and using my own model class, BARTForConditionalGenerationTest (pasted below). So the change is model = BARTForConditionalGenerationTest.from_pretrained(...) instead of model = AutoModelForSeq2SeqLM.from_pretrained(...), as sketched below.
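A sketch of the swap (the surrounding keyword arguments are exactly the ones already in the example script; I elide most of them here):

model = BARTForConditionalGenerationTest.from_pretrained(  # was: AutoModelForSeq2SeqLM.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    # ... remaining kwargs (cache_dir, revision, ...) unchanged
)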
  2. The test model adds the self-attention + LayerNorm module, which I copied directly from BartEncoderLayer:
import torch
import torch.nn as nn

from transformers.models.bart.modeling_bart import (
    BartForConditionalGeneration,
    BartModel,
    BartAttention,
    shift_tokens_right,
    _expand_mask,
)

from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    Seq2SeqModelOutput,
    BaseModelOutput,
)

class BARTModelTest(BartModel):
    def __init__(self, config):
        super().__init__(config)

        # additional layer to showcase the layernorm issue
        self.embed_dim = config.d_model
        self.self_attn = BartAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

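        # post_init() randomly initializes the modules added above; from_pretrained()
        # later restores only the pretrained BART weights, so these new modules
        # remain randomly initialized (transformers warns about this at load time)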
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # different to other models, Bart automatically creates decoder_input_ids from
        # input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )

            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # as in the upstream BartModel.forward: wrap tuple outputs so the
            # attribute access on encoder_outputs below always works
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )
        
        # NEW: Pass to another self attention
        hidden_states = encoder_outputs.last_hidden_state

        residual = hidden_states

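        # _expand_mask turns the [bsz, seq_len] padding mask into the additive
        # [bsz, 1, tgt_len, src_len] mask that BartAttention expects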
        _attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=_attention_mask,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        # Problematic LayerNorm Layer
        hidden_states = self.self_attn_layer_norm(hidden_states)

        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        # Problematic LayerNorm Layer
        hidden_states = self.final_layer_norm(hidden_states)

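        # NEW: write the transformed states back into encoder_outputs (this
        # mutates the encoder output object in place before the decoder uses it)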
        encoder_outputs.last_hidden_state = hidden_states


        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class BARTForConditionalGenerationTest(BartForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        self.model = BARTModelTest(config)
        # Initialize weights and apply final processing
        self.post_init()

Note the lines marked with the comment # NEW, and the two LayerNorm calls flagged with # Problematic LayerNorm Layer.
  3. Running this on XSum with a single GPU (I have also tried multi-GPU, DeepSpeed, and so forth):

python examples/pytorch/summarization/run_summarization.py \
--dataset_name xsum --do_train \
--model_name_or_path facebook/bart-base \
--tokenizer_name facebook/bart-base \
--do_eval --evaluation_strategy steps --eval_steps 10 --predict_with_generate \
--per_device_train_batch_size 64 --per_device_eval_batch_size 16 \
--gradient_accumulation_steps 1 \
--learning_rate 3e-05 --weight_decay 0.01 --label_smoothing_factor 0.1 \
--max_source_length 512 --max_target_length 64 \
--logging_steps 100 --max_steps 5000 \
--warmup_steps 0 --save_steps 1000 \
--output_dir test_layernorm --max_eval_samples 10 --max_train_samples 1000 --max_predict_samples 100

I stopped each run after 30 steps.

---- Results ----

  1. Running with the original AutoModelForSeq2SeqLM (evaluations at steps 10, 20, and 30):
{'eval_loss': 3.429733991622925, 'eval_rouge1': 35.3788, 'eval_rouge2': 11.958, 'eval_rougeL': 28.7712, 'eval_rougeLsum': 28.8147, 'eval_gen_len': 19.6, 'eval_runtime': 0.4073, 'eval_samples_per_second': 24.552, 'eval_steps_per_second': 2.455, 'epoch': 0.62}
{'eval_loss': 3.320158004760742, 'eval_rouge1': 30.3056, 'eval_rouge2': 10.7887, 'eval_rougeL': 28.2016, 'eval_rougeLsum': 28.0782, 'eval_gen_len': 19.8, 'eval_runtime': 0.3998, 'eval_samples_per_second': 25.01, 'eval_steps_per_second': 2.501, 'epoch': 1.25}
{'eval_loss': 3.2896971702575684, 'eval_rouge1': 30.415, 'eval_rouge2': 8.1278, 'eval_rougeL': 27.7237, 'eval_rougeLsum': 27.6498, 'eval_gen_len': 20.0, 'eval_runtime': 0.3894, 'eval_samples_per_second': 25.681, 'eval_steps_per_second': 2.568, 'epoch': 1.88}
  2. Running with my model but commenting out the two lines that call the LayerNorms (i.e. hidden_states = self.self_attn_layer_norm(hidden_states) and hidden_states = self.final_layer_norm(hidden_states)):
{'eval_loss': 3.460312604904175, 'eval_rouge1': 32.4359, 'eval_rouge2': 9.7464, 'eval_rougeL': 27.5792, 'eval_rougeLsum': 27.4135, 'eval_gen_len': 19.1, 'eval_runtime': 1.0524, 'eval_samples_per_second': 9.502, 'eval_steps_per_second': 0.95, 'epoch': 0.62}
{'eval_loss': 3.37113881111145, 'eval_rouge1': 29.4708, 'eval_rouge2': 7.4381, 'eval_rougeL': 24.7256, 'eval_rougeLsum': 24.5516, 'eval_gen_len': 19.9, 'eval_runtime': 0.7387, 'eval_samples_per_second': 13.538, 'eval_steps_per_second': 1.354, 'epoch': 1.25}
{'eval_loss': 3.33235239982605, 'eval_rouge1': 33.9623, 'eval_rouge2': 11.8778, 'eval_rougeL': 30.1785, 'eval_rougeLsum': 30.1524, 'eval_gen_len': 19.7, 'eval_runtime': 0.7438, 'eval_samples_per_second': 13.444, 'eval_steps_per_second': 1.344, 'epoch': 1.88}
  3. Running my model with the LayerNorms:
{'eval_loss': 9.264244079589844, 'eval_rouge1': 8.4575, 'eval_rouge2': 0.0, 'eval_rougeL': 7.8523, 'eval_rougeLsum': 7.8706, 'eval_gen_len': 20.0, 'eval_runtime': 0.7076, 'eval_samples_per_second': 14.133, 'eval_steps_per_second': 1.413, 'epoch': 0.62}
{'eval_loss': 8.134066581726074, 'eval_rouge1': 14.0672, 'eval_rouge2': 1.2222, 'eval_rougeL': 12.6982, 'eval_rougeLsum': 13.1708, 'eval_gen_len': 18.3, 'eval_runtime': 0.7573, 'eval_samples_per_second': 13.205, 'eval_steps_per_second': 1.32, 'epoch': 1.25}
{'eval_loss': 7.54071569442749, 'eval_rouge1': 5.2054, 'eval_rouge2': 0.0, 'eval_rougeL': 5.0935, 'eval_rougeLsum': 5.1303, 'eval_gen_len': 11.5, 'eval_runtime': 0.7393, 'eval_samples_per_second': 13.526, 'eval_steps_per_second': 1.353, 'epoch': 1.88}

I understand that the self_attn, feed-forward, and LayerNorm modules are newly initialized, but I expected them to be trained and updated, eventually reaching similar performance. As the second run shows, with self_attn and the feed-forward layers added (but no LayerNorm) the model updates correctly and achieves performance similar to the unmodified BART. However, adding only the LayerNorm calls on top of that makes the model unusable (the third run), and I have no clue why.
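For reference, here is a tiny self-contained sketch (plain PyTorch, independent of BART or my model) of what a freshly initialized nn.LayerNorm does: its weight starts at 1 and its bias at 0, so from the very first forward pass it forces its input to roughly zero mean and unit variance per position, whatever the original scale was:

import torch
import torch.nn as nn

torch.manual_seed(0)
ln = nn.LayerNorm(768)
print(ln.weight.mean().item(), ln.bias.mean().item())  # 1.0 0.0 at initialization

# stand-in tensor with a non-standard scale/offset, mimicking encoder states
x = torch.randn(2, 5, 768) * 4.0 + 1.0
y = ln(x)
print(x.mean().item(), x.std().item())  # roughly 1.0 and 4.0
print(y.mean().item(), y.std().item())  # roughly 0.0 and 1.0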

Any help is appreciated!