Can't load Longformer model built on top of MBART

Hello there,
I have built a Longformer Encoder Decoder on top of a MBart architecture by simply following instructions provided at (longformer/convert_bart_to_longformerencoderdecoder.py at master · allenai/longformer · GitHub).

This is the huggingface MBart model → ARTeLab/mbart-summarization-fanpage

In doing so, I first updated the imports from the ‘transformers’ library; second, since I am working on Google Colab to use a GPU, I moved all necessary classes into a .ipynb file.

When I try to load the model I get a size mismatch for model.encoder.embed_positions.weight error. I have tried to load the model calling different functions provided by transformers but none of them seem to be compatible with the model.

Interestingly, when I load the model via the LongformerModel.from_pretrained(load_model_from) function the model seems to be loaded correctly, but I can’t find a way to run inference.

Code snippets and more detailed explanations are given below.

Environment info

  • transformers version: 4.17.0
  • Platform: Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.7.12
  • PyTorch version (GPU?): 1.10.0+cu111 (True)
  • Tensorflow version (GPU?): 2.8.0 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes, Backend (GPU) Google Compute Engine
  • Using distributed or parallel set-up in script?: no

Who can help

@ydshieh
@patil-suraj

Information

Model I am using (Longformer Encoder Decoder For Conditional Generation, MBART):

The problem arises when using:

  • model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(load_model_from)
  • my own modified scripts: (give details below)

The task I am trying to work on is:

  • Summarization in Italian

To reproduce

Steps to reproduce the behavior:

  1. Run the code
import argparse
import logging
import os
import copy

from transformers import AutoTokenizer
from transformers import MBartForConditionalGeneration
from typing import List, Optional, Tuple, Dict
from torch import nn, Tensor
from transformers.models.longformer.modeling_longformer import LongformerSelfAttention
#from transformers.models.bart.modeling_bart import BartConfig
#from transformers.models.led.modeling_led import LEDForConditionalGeneration, LEDConfig
from transformers import MBartConfig


class LongformerSelfAttentionForBart(nn.Module):
    """Adapter exposing ``LongformerSelfAttention`` through a BART-style attention call.

    Wraps a ``LongformerSelfAttention`` module and applies a linear output
    projection (``self.output``) to its result.
    """

    def __init__(self, config, layer_id):
        """Create the wrapped attention module and the output projection.

        Args:
            config: model configuration. ``config.d_model`` is read here; the
                whole object is also forwarded to ``LongformerSelfAttention``.
            layer_id: layer index forwarded to ``LongformerSelfAttention``.
        """
        super().__init__()
        self.embed_dim = config.d_model
        self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
        self.output = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
        attn_mask: Optional[Tensor] = None,
        need_weights=False,
        output_attentions=False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Run Longformer self-attention on ``query`` and project the result.

        Args:
            query: tensor unpacked as ``(tgt_len, bsz, embed_dim)``; ``embed_dim``
                must equal ``self.embed_dim``.
            key: accepted but not used by this method.
            key_padding_mask: mask that is given two extra dimensions, multiplied
                by -1 and passed on as ``attention_mask``. Although the default
                is ``None``, it is dereferenced unconditionally, so ``None``
                raises ``AttributeError``.
            layer_state: accepted but not used by this method.
            attn_mask: must be ``None`` (asserted).
            need_weights: accepted but not used by this method.
            output_attentions: forwarded to the wrapped attention module.

        Returns:
            ``(attn_output, extra)`` where ``extra`` is the second element of the
            wrapped module's output when that output has exactly two elements,
            otherwise ``None``.

        Raises:
            AssertionError: if the embedding size does not match or ``attn_mask``
                is not ``None``.
        """

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert attn_mask is None

        # NOTE(review): the keyword names below (head_mask, encoder_hidden_states,
        # encoder_attention_mask) must be accepted by the installed
        # LongformerSelfAttention.forward — confirm against the transformers version in use.
        # NOTE(review): presumably key_padding_mask is non-zero at padded positions so that
        # `* -1` marks them with negative values — TODO confirm the expected mask convention.
        outputs = self.longformer_self_attn(
            query.transpose(0, 1),  # LongformerSelfAttention expects (bsz, seqlen, embd_dim)
            attention_mask=key_padding_mask.unsqueeze(dim=1).unsqueeze(dim=1) * -1,
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=output_attentions,
        )

        # Transpose back to the layout of `query` before applying the output projection.
        attn_output = self.output(outputs[0].transpose(0, 1))

        return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)


class LongformerEncoderDecoderForConditionalGeneration(MBartForConditionalGeneration):
    """MBart conditional-generation model whose encoder layers use Longformer self-attention."""

    def __init__(self, config):
        """Build the parent model, then replace ``self_attn`` in every encoder layer.

        Args:
            config: configuration passed to the parent constructor and to each
                ``LongformerSelfAttentionForBart`` created here.
        """
        super().__init__(config)
        # The replacement is unconditional: `config.attention_mode` is not consulted here.
        for index, encoder_layer in enumerate(self.model.encoder.layers):
            replacement = LongformerSelfAttentionForBart(config, layer_id=index)
            encoder_layer.self_attn = replacement


class LongformerEncoderDecoderConfig(MBartConfig):
    """MBart configuration extended with Longformer attention settings."""

    def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
                 autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
                 gradient_checkpointing: bool = False, **kwargs):
        """
        Store the Longformer-specific settings on top of the MBart configuration.

        Args:
            attention_window: list of attention window sizes, one entry per layer.
                The window size is the number of attention locations on each side,
                so for an effective window size of 512 use
                `attention_window=[256]*num_layers` (256 on each side).
            attention_dilation: list of attention dilations, one entry per layer.
                A dilation of `1` means no dilation.
            autoregressive: do autoregressive attention, or attend to both sides.
            attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM
                implementation of Longformer self-attention, 'sliding_chunks' for
                another implementation of Longformer self-attention.
            gradient_checkpointing: stored on the configuration unchanged.
            **kwargs: forwarded to the parent configuration constructor.

        Raises:
            AssertionError: if `attention_mode` is not one of the three values above.
        """
        super().__init__(**kwargs)
        longformer_settings = {
            'attention_window': attention_window,
            'attention_dilation': attention_dilation,
            'autoregressive': autoregressive,
            'attention_mode': attention_mode,
            'gradient_checkpointing': gradient_checkpointing,
        }
        for setting_name, setting_value in longformer_settings.items():
            setattr(self, setting_name, setting_value)
        supported_modes = ('tvm', 'sliding_chunks', 'n2')
        assert self.attention_mode in supported_modes


# Module-level logger used by create_long_model; logging is configured at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def create_long_model(
    save_model_to,
    base_model,
    tokenizer_name_or_path,
    attention_window,
    max_pos
):
    """Convert an MBart checkpoint into a long-input variant and save it.

    The encoder position-embedding table is enlarged to ``max_pos + 2`` rows by
    tiling the learned rows of the original table, and the self-attention of
    every encoder layer is replaced by ``LongformerSelfAttentionForBart`` that
    reuses the original projection weights. The decoder is left unchanged.

    Args:
        save_model_to: directory the converted model and tokenizer are saved to.
        base_model: name or path of the MBart checkpoint to convert.
        tokenizer_name_or_path: name or path of the tokenizer to load.
        attention_window: attention window size used for every layer.
        max_pos: new maximum number of encoder positions (without the 2 offset rows).

    Returns:
        ``(model, tokenizer)``: the converted model and the tokenizer with its
        ``model_max_length`` set to ``max_pos``.

    Raises:
        AssertionError: if the original position table does not have
            ``config.max_position_embeddings + 2`` rows, or if ``max_pos`` is
            smaller than the original maximum.
    """
    model = MBartForConditionalGeneration.from_pretrained(base_model)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos)
    config = MBartConfig.from_pretrained(base_model)
    model.config = config

    # in BART attention_probs_dropout_prob is attention_dropout, but LongformerSelfAttention
    # expects attention_probs_dropout_prob, so set it here
    config.attention_probs_dropout_prob = config.attention_dropout
    config.architectures = ['LongformerEncoderDecoderForConditionalGeneration', ]

    # extend position embeddings
    tokenizer.model_max_length = max_pos
    tokenizer.init_kwargs['model_max_length'] = max_pos
    old_encoder_pos_embed = model.model.encoder.embed_positions.weight
    current_max_pos, embed_size = old_encoder_pos_embed.shape
    assert current_max_pos == config.max_position_embeddings + 2

    config.max_encoder_position_embeddings = max_pos
    config.max_decoder_position_embeddings = config.max_position_embeddings
    # NOTE(review): after this deletion the saved config no longer records
    # `max_position_embeddings`, while the saved encoder table has `max_pos + 2` rows.
    # A model class that sizes the encoder position embeddings from its default
    # `max_position_embeddings` will presumably report a size mismatch for
    # `model.encoder.embed_positions.weight` when loading this checkpoint; the loading
    # class likely has to size the table from `max_encoder_position_embeddings` — confirm.
    del config.max_position_embeddings
    max_pos += 2  # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2
    assert max_pos >= current_max_pos

    # allocate a larger position embedding matrix for the encoder
    new_encoder_pos_embed = old_encoder_pos_embed.new_empty(max_pos, embed_size)
    # `new_empty` returns uninitialised memory, so carry the two reserved rows over
    # explicitly instead of saving whatever happened to be in the buffer.
    new_encoder_pos_embed[:2] = old_encoder_pos_embed[:2]
    # copy position embeddings over and over to initialize the new position embeddings
    k = 2
    step = current_max_pos - 2
    while k < max_pos:
        # The last chunk is truncated when `max_pos - 2` is not a multiple of `step`,
        # so every row is filled and the slice shapes always match.
        chunk = min(step, max_pos - k)
        new_encoder_pos_embed[k:(k + chunk)] = old_encoder_pos_embed[2:(2 + chunk)]
        k += chunk
    model.model.encoder.embed_positions.weight.data = new_encoder_pos_embed

    # The decoder position embeddings are deliberately left at their original size;
    # the equivalent extension for the decoder is kept here, disabled, for reference.
    # allocate a larger position embedding matrix for the decoder
    # new_decoder_pos_embed = model.model.decoder.embed_positions.weight.new_empty(max_pos, embed_size)
    # # copy position embeddings over and over to initialize the new position embeddings
    # k = 2
    # step = current_max_pos - 2
    # while k < max_pos - 1:
    #     new_decoder_pos_embed[k:(k + step)] = model.model.decoder.embed_positions.weight[2:]
    #     k += step
    # model.model.decoder.embed_positions.weight.data = new_decoder_pos_embed

    # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention`
    config.attention_window = [attention_window] * config.num_hidden_layers
    config.attention_dilation = [1] * config.num_hidden_layers

    for i, layer in enumerate(model.model.encoder.layers):
        longformer_self_attn_for_bart = LongformerSelfAttentionForBart(config, layer_id=i)

        # Local attention shares the original projection modules ...
        longformer_self_attn_for_bart.longformer_self_attn.query = layer.self_attn.q_proj
        longformer_self_attn_for_bart.longformer_self_attn.key = layer.self_attn.k_proj
        longformer_self_attn_for_bart.longformer_self_attn.value = layer.self_attn.v_proj

        # ... while global attention starts from independent copies of them.
        longformer_self_attn_for_bart.longformer_self_attn.query_global = copy.deepcopy(layer.self_attn.q_proj)
        longformer_self_attn_for_bart.longformer_self_attn.key_global = copy.deepcopy(layer.self_attn.k_proj)
        longformer_self_attn_for_bart.longformer_self_attn.value_global = copy.deepcopy(layer.self_attn.v_proj)

        longformer_self_attn_for_bart.output = layer.self_attn.out_proj

        layer.self_attn = longformer_self_attn_for_bart
    logger.info(f'saving model to {save_model_to}')
    model.save_pretrained(save_model_to)
    tokenizer.save_pretrained(save_model_to)
    return model, tokenizer


def main(base_model, tokenizer, save_model_to, attention_window=512, max_pos=4096 * 4):
    """Create the output directory and build the long-input model.

    Args:
        base_model: name or path of the checkpoint to convert.
        tokenizer: name or path of the tokenizer to load.
        save_model_to: directory the converted model and tokenizer are saved to;
            it is created (including missing parents) when it does not exist.
        attention_window: attention window size passed to ``create_long_model``.
        max_pos: new maximum number of encoder positions.

    Returns:
        ``(model, tokenizer)`` exactly as returned by ``create_long_model``, i.e.
        the converted model and the loaded tokenizer object (not the
        ``tokenizer`` name that was passed in).
    """
    os.makedirs(save_model_to, exist_ok=True)

    model, long_tokenizer = create_long_model(
        save_model_to=save_model_to,
        base_model=base_model,
        tokenizer_name_or_path=tokenizer,
        attention_window=attention_window,
        max_pos=max_pos,
    )
    return model, long_tokenizer

# Convert the ARTeLab/mbart-summarization-fanpage checkpoint (window 512, 16384 encoder positions)
# and save the result under /content/model.
model_, tokenizer_ = main(base_model = 'ARTeLab/mbart-summarization-fanpage', tokenizer = 'ARTeLab/mbart-summarization-fanpage', save_model_to = "/content/model", attention_window = 512, max_pos = 4096 * 4)
  2. Load the Model
from transformers.models.bart.tokenization_bart_fast import BartTokenizerFast
load_model_from = "/content/model"
# NOTE(review): the saved tokenizer class is reported as MBartTokenizer, so loading it through
# BartTokenizerFast triggers the class-mismatch warning quoted in this report.
tokenizer = BartTokenizerFast.from_pretrained(load_model_from)
# This is the call that raises the size mismatch for model.encoder.embed_positions.weight.
model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(load_model_from)
  3. Get the following output

Tokenizer:
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization.
The tokenizer class you load from this checkpoint is ‘MBartTokenizer’.
The class this function is called from is ‘BartTokenizerFast’.

Model
RuntimeError: Error(s) in loading state_dict for LongformerEncoderDecoderForConditionalGeneration:
size mismatch for model.encoder.embed_positions.weight: copying a param with shape torch.Size([16386, 1024]) from checkpoint, the shape in current model is torch.Size([1026, 1024]).

  4. The other loading methods I have tried to call:
  • model = LEDForConditionalGeneration.from_pretrained(load_model_from)
    Error message
    ValueError: The state dictionary of the model you are training to load is corrupted. Are you sure it was properly saved?
  • model = EncoderDecoderModel.from_pretrained(load_model_from)
    Error message
    AssertionError: Config has to be initialized with encoder and decoder config
  • model = LongformerModel.from_pretrained(load_model_from)
    Model is loaded but with warnings
    You are using a model of type mbart to instantiate a model of type longformer. This is not supported for all configurations of models and can yield errors.
    Some weights of the model checkpoint at /content/model were not used when initializing LongformerModel:
    This IS expected if you are initializing LongformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    This IS NOT expected if you are initializing LongformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

When I try to run inference:

# Attempt 1: tokenize one padded input and call generate() on the loaded model.
from transformers.models.bart.tokenization_bart_fast import BartTokenizerFast
from transformers.models.bart.modeling_bart import shift_tokens_right

tokenizer = BartTokenizerFast.from_pretrained(load_model_from)
TXT = "an article..."
data = tokenizer([TXT], return_tensors='pt', padding='max_length', max_length=2048)
input_ids = data['input_ids']
attention_mask = data['attention_mask']
# Decoder inputs are built from the first five input tokens only.
decoder_input_ids = shift_tokens_right(input_ids[:, :5], tokenizer.pad_token_id, decoder_start_token_id = 250011)
# NOTE(review): the input tensor is passed under the keyword `main_input_name`; presumably
# `input_ids=input_ids` was intended — confirm against the generate() signature.
logits = model.generate(main_input_name = input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False)[0]
# `.item()` requires exactly one mask token in the input.
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5)
print(tokenizer.convert_ids_to_tokens(predictions))

Error message
AttributeError: ‘LongformerEncoder’ object has no attribute ‘main_input_name’

or

# Attempt 2: forward pass on a long input with explicit local and global attention masks.
import torch
text = " ".join(["Hello world! "] * 1000)  # long input document
input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)  # batch of size 1

attention_mask = torch.ones(
    input_ids.shape, dtype=torch.long, device=input_ids.device
)  # initialize to local attention
global_attention_mask = torch.zeros(
    input_ids.shape, dtype=torch.long, device=input_ids.device
)  # initialize to global attention to be deactivated for all tokens
global_attention_mask[
    :,
    [
        1,
        4,
        21,
    ],
] = 1  # Set global attention to random tokens for the sake of this example
# Usually, set global attention based on the task. For example,
# classification: the <s> token
# QA: question tokens
# LM: potentially on the beginning of sentences and paragraphs
# NOTE(review): the reported IndexError presumably comes from a token id or position id that
# falls outside an embedding table of the loaded model — TODO confirm which table.
outputs = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)

Error message
IndexError: index out of range in self

or

# Attempt 3: plain forward pass, then try to decode the first position of the output.
inputs    = tokenizer(TXT, return_tensors="pt")
#inputs = {k: v.cuda() for k, v in inputs.items()}
outputs   = model(**inputs)
# `features` is a NumPy array taken from the first model output at sequence position 0.
features  = outputs[0][:,0,:].detach().numpy().squeeze()
# NOTE(review): `features` is passed to decode() in place of token ids; presumably it holds
# floats, which would explain the reported TypeError — confirm what outputs[0] contains.
print(tokenizer.decode(features, skip_special_tokens=True, clean_up_tokenization_spaces=True))

Error message
TypeError: ‘float’ object cannot be interpreted as an integer

Expected behavior

Model loaded successfully!
Inference likewise :slight_smile: