Convert models to Longformer

My request was posted as an issue.

Environment info

  • transformers version: 4.2.0
  • Platform: Linux-4.15.0-123-generic-x86_64-with-glibc2.10
  • Python version: 3.8.5
  • PyTorch version (GPU?): 1.7.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Information

Model I am using: a custom script that initializes a Longformer starting from HerBERT.

The problem arises when using:

  • [ ] the official example scripts: (give details below)
  • [x] my own modified scripts: (give details below)

The task I am working on is:

  • [ ] an official GLUE/SQUaD task: (give the name)
  • [x] my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. Install dependencies: python3 -m pip install -r requirements.txt.
  2. Install apex according to official documentation.
  3. Run command CUDA_VISIBLE_DEVICES=0 python3 convert_model_to_longformer.py --finetune_dataset conllu.

We are using a dataset in .jsonl format, where each line contains one CoNLL-U entry. It is converted to the line-by-line format using a custom LineByLineTextDataset class copied from the current version of transformers; I added this class so that it can also be used with the older version (v3.0.2).
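
For illustration, a hypothetical (simplified) example of a single .jsonl line in the format that conllu.py below expects; each line is a JSON object with a conllu field, and the # text = comments carry the sentence text that ends up in the line-by-line file:

import json

# Hypothetical, simplified entry (token columns truncated); only the
# "# text = ..." comment lines are used when building the line-by-line dataset.
example_line = {
    "conllu": "# text = Ala ma kota\n1\tAla\tAla\tPROPN\n2\tma\tmieć\tVERB\n3\tkota\tkot\tNOUN"
}
print(json.dumps(example_line, ensure_ascii=False))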

Following the suggestion of the author of allenai/longformer, I used transformers version 3.0.2 and it works fine. However, I would like to use recent versions of the library to convert models to their Long* variants, and I cannot make the conversion script work there.

Result

Running the command above with transformers 4.2.0, I got:

Traceback (most recent call last):
  File "convert_model_to_longformer.py", line 277, in <module>
    pretrain_and_evaluate(
  File "convert_model_to_longformer.py", line 165, in pretrain_and_evaluate
    eval_loss = trainer.evaluate()
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/transformers/trainer.py", line 1442, in evaluate
    output = self.prediction_loop(
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/transformers/trainer.py", line 1566, in prediction_loop
    loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/transformers/trainer.py", line 1670, in prediction_step
    outputs = model(**inputs)
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py", line 1032, in forward
    outputs = self.roberta(
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py", line 798, in forward
    encoder_outputs = self.encoder(
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py", line 498, in forward
    layer_outputs = layer_module(
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py", line 393, in forward
    self_attention_outputs = self.attention(
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py", line 321, in forward
    self_outputs = self.self(
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "convert_model_to_longformer.py", line 63, in forward
    return super().forward(hidden_states, attention_mask=attention_mask, output_attentions=output_attentions) # v4.2.0
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/transformers/models/longformer/modeling_longformer.py", line 600, in forward
    diagonal_mask = self._sliding_chunks_query_key_matmul(
  File "/server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/transformers/models/longformer/modeling_longformer.py", line 789, in _sliding_chunks_query_key_matmul
    batch_size, seq_len, num_heads, head_dim = query.size()
ValueError: too many values to unpack (expected 4)

I added print statements to the forward function in /server/server_1/user/miniconda3/envs/longformer_summary/lib/python3.8/site-packages/transformers/models/longformer/modeling_longformer.py, which leads up to the failing call at line 789:

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        is_index_masked=None,
        is_index_global_attn=None,
        is_global_attn=None,
        output_attentions=False,
    ):
        """
        :class:`LongformerSelfAttention` expects `len(hidden_states)` to be multiple of `attention_window`. Padding to
        `attention_window` happens in :meth:`LongformerModel.forward` to avoid redoing the padding on each layer.

        The `attention_mask` is changed in :meth:`LongformerModel.forward` from 0, 1, 2 to:

            * -10000: no attention
            * 0: local attention
            * +10000: global attention
        """
        hidden_states = hidden_states.transpose(0, 1)

        # project hidden states
        query_vectors = self.query(hidden_states)
        key_vectors = self.key(hidden_states)
        value_vectors = self.value(hidden_states)
        print(f"query_vectors: {query_vectors.shape}")
        print(f"key_vectors: {key_vectors.shape}")
        print(f"value_vectors: {value_vectors.shape}")
        print(f"attention_mask: {attention_mask.shape}")


        seq_len, batch_size, embed_dim = hidden_states.size()
        assert (
            embed_dim == self.embed_dim
        ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"

        # normalize query
        query_vectors /= math.sqrt(self.head_dim)

        query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
        key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

        attn_scores = self._sliding_chunks_query_key_matmul(
            query_vectors, key_vectors, self.one_sided_attn_window_size
        )

        # values to pad for attention probs
        remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]

        # cast to fp32/fp16 then replace 1's with -inf
        float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
            remove_from_windowed_attention_mask, -10000.0
        )

        print(f"attn_scores: {attn_scores.shape}")
        print(f"remove_from_windowed_attention_mask: {remove_from_windowed_attention_mask.shape}")
        print(f"float_mask: {float_mask.shape}")

        # diagonal mask with zeros everywhere and -inf inplace of padding
        diagonal_mask = self._sliding_chunks_query_key_matmul(
            float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
        )

This produced:

attention_mask: torch.Size([2, 1, 1, 1024])
query_vectors: torch.Size([1024, 2, 768])
key_vectors: torch.Size([1024, 2, 768])
value_vectors: torch.Size([1024, 2, 768])
attn_scores: torch.Size([2, 1024, 12, 513])
remove_from_windowed_attention_mask: torch.Size([2, 1, 1, 1, 1, 1024])
float_mask: torch.Size([2, 1, 1, 1, 1, 1024])

After switching back to version 3.0.2 and adding the same print statements, I got:

attention_mask: torch.Size([2, 1024])
query_vectors: torch.Size([1024, 2, 768])
key_vectors: torch.Size([1024, 2, 768])
value_vectors: torch.Size([1024, 2, 768])
attn_scores: torch.Size([2, 1024, 12, 513])
remove_from_windowed_attention_mask: torch.Size([2, 1024, 1, 1])
float_mask: torch.Size([2, 1024, 1, 1])

So maybe the problem is with the _sliding_chunks_query_key_matmul function?
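
Just to spell out where the extra dimensions come from, a plain shape check that matches the printouts above:

import torch

mask_v3 = torch.zeros(2, 1024)        # 2D attention mask, as printed for v3.0.2
mask_v4 = torch.zeros(2, 1, 1, 1024)  # 4D extended mask, as printed for v4.2.0

# The indexing used to build remove_from_windowed_attention_mask:
print((mask_v3 != 0)[:, :, None, None].shape)  # torch.Size([2, 1024, 1, 1])
print((mask_v4 != 0)[:, :, None, None].shape)  # torch.Size([2, 1, 1, 1, 1, 1024])

This suggests the extra dimensions are already present in float_mask before _sliding_chunks_query_key_matmul is called.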

Files:

convert_model_to_longformer.py, based on allenai/longformer/scripts/convert_model_to_long.ipynb:

import logging
import os
import math
import copy
import torch
import argparse
from dataclasses import dataclass, field
from transformers import RobertaForMaskedLM, XLMTokenizer, TextDataset, DataCollatorForLanguageModeling, Trainer, PreTrainedTokenizer
from transformers import TrainingArguments, HfArgumentParser, RobertaModel
from transformers import LongformerSelfAttention # v4.2.0
# from transformers.modeling_longformer import LongformerSelfAttention # v3.0.2
from conllu import load_conllu_dataset, save_conllu_dataset_in_linebyline_format
from torch.utils.data.dataset import Dataset

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)



class LineByLineTextDataset(Dataset):
    """
    This will be superseded by a framework-agnostic approach
    soon.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
        assert os.path.isfile(file_path)
        # Here, we do not cache the features, operating under the assumption
        # that we will soon use fast multithreaded tokenizers from the
        # `tokenizers` repo everywhere =)
        logger.info("Creating features from dataset file at %s", file_path)

        with open(file_path, encoding="utf-8") as f:
            lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]

        batch_encoding = tokenizer(
            lines,
            add_special_tokens=True,
            truncation=True,
            padding="max_length",
            max_length=block_size,
            pad_to_multiple_of=512)
        self.examples = batch_encoding["input_ids"]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i) -> torch.Tensor:
        return torch.tensor(self.examples[i], dtype=torch.long)


class RobertaLongSelfAttention(LongformerSelfAttention):
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
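        # In v4.2.0 the RoBERTa layer calls this with the 4D extended attention
        # mask of shape (batch, 1, 1, seq_len); head_mask and the encoder/
        # past_key_value arguments exist only to match the caller's signature
        # and are not forwarded to LongformerSelfAttention.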
        return super().forward(hidden_states, attention_mask=attention_mask, output_attentions=output_attentions)


class RobertaLongForMaskedLM(RobertaForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        for i, layer in enumerate(self.roberta.encoder.layer):
            # replace the `modeling_bert.BertSelfAttention` object with `LongformerSelfAttention`
            layer.attention.self = RobertaLongSelfAttention(config, layer_id=i)


class RobertaLongModel(RobertaModel):
    def __init__(self, config):
        super().__init__(config)
        for i, layer in enumerate(self.encoder.layer):
            # replace the `modeling_bert.BertSelfAttention` object with `LongformerSelfAttention`
            layer.attention.self = RobertaLongSelfAttention(config, layer_id=i)


def create_long_model(initialization_model, initialization_tokenizer, save_model_to, attention_window, max_pos):
    model = RobertaForMaskedLM.from_pretrained(initialization_model)
    tokenizer = XLMTokenizer.from_pretrained(initialization_tokenizer, model_max_length=max_pos)
    config = model.config

    # extend position embeddings
    tokenizer.model_max_length = max_pos
    tokenizer.init_kwargs['model_max_length'] = max_pos
    current_max_pos, embed_size = model.roberta.embeddings.position_embeddings.weight.shape
    max_pos += 2  # NOTE: RoBERTa has positions 0,1 reserved, so embedding size is max position + 2
    config.max_position_embeddings = max_pos
    assert max_pos > current_max_pos

    # allocate a larger position embedding matrix
    new_pos_embed = model.roberta.embeddings.position_embeddings.weight.new_empty(max_pos, embed_size)

    # copy position embeddings over and over to initialize the new position embeddings
    k = 2
    step = current_max_pos - 2
    while k < max_pos - 1:
        new_pos_embed[k:(k + step)] = model.roberta.embeddings.position_embeddings.weight[2:]
        k += step
    model.roberta.embeddings.position_embeddings.weight.data = new_pos_embed
    model.roberta.embeddings.position_ids.data = torch.tensor([i for i in range(max_pos)]).reshape(1, max_pos) # v4.2.0
    # model.roberta.embeddings.position_ids = torch.tensor([i for i in range(max_pos)]).reshape(1, max_pos) # v3.0.2

    # replace the `modeling_bert.BertSelfAttention` object with `LongformerSelfAttention`
    config.attention_window = [attention_window] * config.num_hidden_layers
    for i, layer in enumerate(model.roberta.encoder.layer):
        longformer_self_attn = LongformerSelfAttention(config, layer_id=i)
        longformer_self_attn.query = copy.deepcopy(layer.attention.self.query)
        longformer_self_attn.key = copy.deepcopy(layer.attention.self.key)
        longformer_self_attn.value = copy.deepcopy(layer.attention.self.value)

        longformer_self_attn.query_global = copy.deepcopy(layer.attention.self.query)
        longformer_self_attn.key_global = copy.deepcopy(layer.attention.self.key)
        longformer_self_attn.value_global = copy.deepcopy(layer.attention.self.value)

        layer.attention.self = longformer_self_attn

    logger.info(f'saving model to {save_model_to}')
    model.save_pretrained(save_model_to)
    tokenizer.save_pretrained(save_model_to)
    return model, tokenizer

def copy_proj_layers(model):
    for i, layer in enumerate(model.roberta.encoder.layer):
        layer.attention.self.query_global = copy.deepcopy(layer.attention.self.query)
        layer.attention.self.key_global = copy.deepcopy(layer.attention.self.key)
        layer.attention.self.value_global = copy.deepcopy(layer.attention.self.value)
    return model

def pretrain_and_evaluate(args, model, tokenizer, eval_only, model_path, max_size):
    val_dataset = LineByLineTextDataset(
        tokenizer=tokenizer,
        file_path=args.val_datapath,
        block_size=max_size,
    )
    if eval_only:
        train_dataset = val_dataset
    else:
        logger.info(f'Loading and tokenizing training data is usually slow: {args.train_datapath}')
        train_dataset = LineByLineTextDataset(
            tokenizer=tokenizer,
            file_path=args.train_datapath,
            block_size=max_size,
        )

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=0.15,
    )
    trainer = Trainer(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        # prediction_loss_only=True,
    )

    eval_loss = trainer.evaluate()
    eval_loss = eval_loss['eval_loss']
    logger.info(f'Initial eval bpc: {eval_loss/math.log(2)}')

    if not eval_only:
        trainer = Trainer(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            prediction_loss_only=False,
        )

        trainer.train(model_path=model_path)
        trainer.save_model()

        eval_loss = trainer.evaluate()
        eval_loss = eval_loss['eval_loss']
        logger.info(f'Eval bpc after pretraining: {eval_loss/math.log(2)}')

@dataclass
class ModelArgs:
    attention_window: int = field(default=512, metadata={"help": "Size of attention window"})
    max_pos: int = field(default=1024, metadata={"help": "Maximum position"})


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--finetune_dataset", required=True, choices=["conllu"], help="Name of dataset to finetune")
    return parser.parse_args()

if __name__ == "__main__":
    parser = HfArgumentParser((TrainingArguments, ModelArgs,))
    args = parse_args()

    training_args, model_args = parser.parse_args_into_dataclasses(look_for_args_file=False, args=[
        '--output_dir', 'tmp_4.2.0',
        '--warmup_steps', '500',
        '--learning_rate', '0.00003',
        '--weight_decay', '0.01',
        '--adam_epsilon', '1e-6',
        '--max_steps', '3000',
        '--logging_steps', '500',
        '--save_steps', '500',
        '--max_grad_norm', '5.0',
        '--per_device_eval_batch_size', '2',
        '--per_device_train_batch_size', '2',
        '--gradient_accumulation_steps', '4',
        # '--evaluate_during_training',
        '--do_train',
        '--do_eval',
        '--fp16',
        '--fp16_opt_level', 'O2',
    ])

    if args.finetune_dataset == "conllu":
        saved_dataset = '/server/server_1/user/longformer_summary/conllu/'
        if not os.path.exists(saved_dataset):
            os.makedirs(saved_dataset)
            dataset = load_conllu_dataset('/server/server_1/user/conllu_dataset/')
            save_conllu_dataset_in_linebyline_format(dataset, saved_dataset)

    training_args.val_datapath = os.path.join(saved_dataset, 'validation.txt')
    training_args.train_datapath = os.path.join(saved_dataset, 'train.txt')

    initialization_model = 'allegro/herbert-klej-cased-v1'
    initialization_tokenizer = 'allegro/herbert-klej-cased-tokenizer-v1'

    roberta_base = RobertaForMaskedLM.from_pretrained(initialization_model)
    roberta_base_tokenizer = XLMTokenizer.from_pretrained(initialization_tokenizer, model_max_length=512)

    model_path = f'{training_args.output_dir}/{initialization_model}-{model_args.max_pos}'
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    logger.info(f'Converting roberta-base into {initialization_model}-{model_args.max_pos}')
    model, tokenizer = create_long_model(
        initialization_model=initialization_model,
        initialization_tokenizer=initialization_tokenizer,
        save_model_to=model_path,
        attention_window=model_args.attention_window,
        max_pos=model_args.max_pos,
    )

    logger.info(f'Loading the model from {model_path}')
    tokenizer = XLMTokenizer.from_pretrained(model_path)
    model = RobertaLongForMaskedLM.from_pretrained(model_path)

    logger.info(f'Pretraining {initialization_model}-{model_args.max_pos} ... ')

    pretrain_and_evaluate(
        training_args,
        model,
        tokenizer,
        eval_only=False,
        model_path=training_args.output_dir,
        max_size=model_args.max_pos,
    )

    logger.info(f'Copying local projection layers into global projection layers... ')
    model = copy_proj_layers(model)
    logger.info(f'Saving model to {model_path}')
    model.save_pretrained(model_path)

    logger.info(f'Loading the model from {model_path}')
    tokenizer = XLMTokenizer.from_pretrained(model_path)
    model = RobertaLongModel.from_pretrained(model_path)

conllu.py

import re
import glob
import torch
from torch.utils.data import Dataset
import time
import os
import json
from xml.etree.ElementTree import ParseError
import xml.etree.ElementTree as ET
from typing import List, Dict
from sklearn.model_selection import train_test_split


def load_conllu_jsonl(
    path: str,
) -> List[Dict[str, str]]:
    dataset: List[Dict[str, str]] = list()
    with open(path, 'r') as f:
        for jsonl in f.readlines():
            json_file = json.loads(jsonl)
            conllu = json_file['conllu'].split('\n')
            doc_text: str = ""
            utterance: Dict[str, str] = dict()
            for line in conllu:
                try:
                    if line[0].isdigit():
                        if utterance:
                            masked_text = utterance["text"]
                            doc_text = f"{doc_text} {masked_text}.".strip()
                            utterance = dict()
                    elif line[0] == '#':
                        text = line[1:].strip()
                        key = text.split('=')[0].strip()
                        value = text.split('=')[1].strip()
                        utterance[key] = value
                except IndexError:
                    pass

            dataset.append({"text": doc_text})

    return dataset


def load_conllu_dataset(
    path: str,
    train_test_val_ratio: float = 0.1,
) -> Dict[str, List[Dict[str, str]]]:
    dataset: Dict[str, List[Dict[str, str]]] = dict()
    data_dict: Dict[str, List[str]] = dict()

    filepath_list = glob.glob(os.path.join(path, '*.jsonl'))

    train = filepath_list[:int(len(filepath_list)*0.8)]
    test = filepath_list[int(len(filepath_list)*0.8):int(len(filepath_list)*0.9)]
    val = filepath_list[int(len(filepath_list)*0.9):]

    data_dict["test"] = test
    data_dict["train"] = train
    data_dict["validation"] = val

    for key, value in data_dict.items():
        dataset_list: List[Dict[str, str]] = list()
        for filepath in value:
            data = load_conllu_jsonl(path=filepath)
            if data:
                dataset_list.extend(data)

        dataset[key] = dataset_list

    return dataset

def save_conllu_dataset_in_linebyline_format(
    dataset: Dict[str, List[Dict[str, str]]],
    save_dir: str,
) -> None:
    for key, value in dataset.items():
        with open(os.path.join(save_dir, f'{key}.txt'), 'w') as f:
            for line in value:
                # print(line["full"])
                f.write(f'{line["text"]}\n')

requirements.txt:

apex @ file:///server/server_1/user/apex
certifi==2020.12.5
chardet==4.0.0
click==7.1.2
datasets==1.2.0
dill==0.3.3
filelock==3.0.12
idna==2.10
joblib==1.0.0
multiprocess==0.70.11.1
numpy==1.19.4
packaging==20.8
pandas==1.2.0
pyarrow==2.0.0
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2020.5
regex==2020.11.13
requests==2.25.1
sacremoses==0.0.43
sentencepiece==0.1.94
six==1.15.0
tokenizers==0.8.1rc1
torch==1.7.1
tqdm==4.49.0
transformers==3.0.2
typing-extensions==3.7.4.3
urllib3==1.26.2
xxhash==2.0.0

Expected behavior

The model should be converted, saved, and loaded. After that, it should be fine-tuned properly and saved to disk.

Comparing the codebases of versions 3.0.2 and 4.2.0, I noticed that the forward function differs. I added the lines that were removed in 4.2.0 back at the beginning of the function:

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        is_index_masked=None,
        is_index_global_attn=None,
        is_global_attn=None,
        output_attentions=False,
    ):
        """
        :class:`LongformerSelfAttention` expects `len(hidden_states)` to be multiple of `attention_window`. Padding to
        `attention_window` happens in :meth:`LongformerModel.forward` to avoid redoing the padding on each layer.

        The `attention_mask` is changed in :meth:`LongformerModel.forward` from 0, 1, 2 to:

            * -10000: no attention
            * 0: local attention
            * +10000: global attention
        """
        attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)

        # is index masked or global attention
        is_index_masked = attention_mask < 0
        is_index_global_attn = attention_mask > 0
        is_global_attn = any(is_index_global_attn.flatten())

and now the model seems to work, but evaluation returns:

{'eval_loss': nan, 'eval_runtime': 20.6319, 'eval_samples_per_second': 1.939}

Below you can find the results of consecutive steps in the forward function. Can you see anything wrong here?

diagonal_mask: tensor([[[[-inf, -inf, -inf,  ..., 0., 0., 0.]],

         [[-inf, -inf, -inf,  ..., 0., 0., 0.]],

         [[-inf, -inf, -inf,  ..., 0., 0., 0.]],

         ...,

         [[0., 0., 0.,  ..., -inf, -inf, -inf]],

         [[0., 0., 0.,  ..., -inf, -inf, -inf]],

         [[0., 0., 0.,  ..., -inf, -inf, -inf]]],


        [[[-inf, -inf, -inf,  ..., 0., 0., 0.]],

         [[-inf, -inf, -inf,  ..., 0., 0., 0.]],

         [[-inf, -inf, -inf,  ..., 0., 0., 0.]],

         ...,

         [[0., 0., 0.,  ..., -inf, -inf, -inf]],

         [[0., 0., 0.,  ..., -inf, -inf, -inf]],

         [[0., 0., 0.,  ..., -inf, -inf, -inf]]]], device='cuda:0',
       dtype=torch.float16)
attn_scores: tensor([[[[   -inf,    -inf,    -inf,  ...,  0.5771,  0.2065, -1.0449],
          [   -inf,    -inf,    -inf,  ..., -1.3174, -1.5547, -0.6240],
          [   -inf,    -inf,    -inf,  ..., -1.3691, -1.3555, -0.3799],
          ...,
          [   -inf,    -inf,    -inf,  ...,  1.7402,  1.6152,  0.8242],
          [   -inf,    -inf,    -inf,  ...,  0.5122,  1.0342,  0.2091],
          [   -inf,    -inf,    -inf,  ...,  1.7568, -0.1534,  0.7505]],

         [[   -inf,    -inf,    -inf,  ..., -0.8066, -1.7480, -2.5527],
          [   -inf,    -inf,    -inf,  ..., -3.3652,  0.1046, -0.5811],
          [   -inf,    -inf,    -inf,  ..., -0.0958, -1.0957, -0.2377],
          ...,
          [   -inf,    -inf,    -inf,  ..., -0.4148, -0.9497, -0.1229],
          [   -inf,    -inf,    -inf,  ..., -1.9443, -1.3467, -1.5342],
          [   -inf,    -inf,    -inf,  ...,  0.1263, -0.4407,  0.1486]],

         [[   -inf,    -inf,    -inf,  ..., -0.9077, -0.1603, -0.5762],
          [   -inf,    -inf,    -inf,  ..., -0.2454,  0.1932, -0.5034],
          [   -inf,    -inf,    -inf,  ..., -1.4375, -1.2793, -1.0488],
          ...,
          [   -inf,    -inf,    -inf,  ..., -0.3452,  0.1405,  1.3643],
          [   -inf,    -inf,    -inf,  ..., -0.2168, -1.0000, -0.9956],
          [   -inf,    -inf,    -inf,  ..., -1.7451,  0.1410, -0.6221]],

         ...,

         [[-1.3965,  0.7798,  0.4707,  ...,    -inf,    -inf,    -inf],
          [ 0.6260, -0.4146,  0.9180,  ...,    -inf,    -inf,    -inf],
          [ 0.4807, -1.0742,  1.2803,  ...,    -inf,    -inf,    -inf],
          ...,
          [ 0.0909,  0.8022, -0.4170,  ...,    -inf,    -inf,    -inf],
          [-2.6035, -1.2988,  0.5586,  ...,    -inf,    -inf,    -inf],
          [-0.6953, -0.8232,  0.0436,  ...,    -inf,    -inf,    -inf]],

         [[ 1.0889, -0.2776, -0.0632,  ...,    -inf,    -inf,    -inf],
          [-0.4128,  0.4834, -0.3848,  ...,    -inf,    -inf,    -inf],
          [-0.8794,  0.9150, -1.5107,  ...,    -inf,    -inf,    -inf],
          ...,
          [ 0.8867, -0.4731,  0.3389,  ...,    -inf,    -inf,    -inf],
          [-0.1365,  0.4905, -2.0000,  ...,    -inf,    -inf,    -inf],
          [-0.0205, -0.5464, -0.6851,  ...,    -inf,    -inf,    -inf]],

         [[    nan,     nan,     nan,  ...,    -inf,    -inf,    -inf],
          [    nan,     nan,     nan,  ...,    -inf,    -inf,    -inf],
          [    nan,     nan,     nan,  ...,    -inf,    -inf,    -inf],
          ...,
          [    nan,     nan,     nan,  ...,    -inf,    -inf,    -inf],
          [    nan,     nan,     nan,  ...,    -inf,    -inf,    -inf],
          [    nan,     nan,     nan,  ...,    -inf,    -inf,    -inf]]],


        [[[   -inf,    -inf,    -inf,  ..., -4.0469, -2.6270, -5.4805],
          [   -inf,    -inf,    -inf,  ..., -0.9312, -0.6743, -1.9688],
          [   -inf,    -inf,    -inf,  ..., -0.0593, -0.9507, -0.6392],
          ...,
          [   -inf,    -inf,    -inf,  ...,  0.3105,  2.3926,  1.0664],
          [   -inf,    -inf,    -inf,  ..., -0.0166,  2.2754,  1.0449],
          [   -inf,    -inf,    -inf,  ..., -0.4224,  1.7686, -0.2603]],

         [[   -inf,    -inf,    -inf,  ..., -0.5088, -1.2666, -0.4363],
          [   -inf,    -inf,    -inf,  ..., -0.3823, -1.7998, -0.4504],
          [   -inf,    -inf,    -inf,  ..., -0.1525,  0.1614, -0.0267],
          ...,
          [   -inf,    -inf,    -inf,  ...,  0.0225, -0.5737,  0.2318],
          [   -inf,    -inf,    -inf,  ...,  0.7139,  0.6099,  0.3767],
          [   -inf,    -inf,    -inf,  ...,  0.2008, -0.6714,  0.5869]],

         [[   -inf,    -inf,    -inf,  ..., -0.9302, -1.5303, -2.7637],
          [   -inf,    -inf,    -inf,  ..., -0.1124, -0.5850,  0.0818],
          [   -inf,    -inf,    -inf,  ..., -1.5176, -1.7822, -0.9111],
          ...,
          [   -inf,    -inf,    -inf,  ..., -0.3618,  0.3486,  0.4368],
          [   -inf,    -inf,    -inf,  ..., -0.4158, -1.1660, -0.9106],
          [   -inf,    -inf,    -inf,  ..., -0.4636, -0.7012, -0.9570]],

         ...,

         [[-1.0137, -1.2324, -0.2091,  ...,    -inf,    -inf,    -inf],
          [ 0.0793,  0.1862, -0.6162,  ...,    -inf,    -inf,    -inf],
          [ 0.2406,  0.1237, -1.0420,  ...,    -inf,    -inf,    -inf],
          ...,
          [ 0.5308,  0.3862,  0.9731,  ...,    -inf,    -inf,    -inf],
          [-0.5752, -0.8174,  0.4766,  ...,    -inf,    -inf,    -inf],
          [-0.4299, -0.7031, -0.6240,  ...,    -inf,    -inf,    -inf]],

         [[-2.9512, -1.0410,  0.9194,  ...,    -inf,    -inf,    -inf],
          [-0.0306, -0.8579,  0.1930,  ...,    -inf,    -inf,    -inf],
          [ 0.2927, -1.4600, -1.6787,  ...,    -inf,    -inf,    -inf],
          ...,
          [ 0.6128, -0.8921,  1.2861,  ...,    -inf,    -inf,    -inf],
          [-0.7778, -0.8564,  2.3457,  ...,    -inf,    -inf,    -inf],
          [-0.8877, -1.4834,  0.7783,  ...,    -inf,    -inf,    -inf]],

         [[    nan,     nan,     nan,  ...,    -inf,    -inf,    -inf],
          [    nan,     nan,     nan,  ...,    -inf,    -inf,    -inf],
          [    nan,     nan,     nan,  ...,    -inf,    -inf,    -inf],
          ...,
          [    nan,     nan,     nan,  ...,    -inf,    -inf,    -inf],
          [    nan,     nan,     nan,  ...,    -inf,    -inf,    -inf],
          [    nan,     nan,     nan,  ...,    -inf,    -inf,    -inf]]]],
       device='cuda:0', dtype=torch.float16)

Hey @adamwawrzynski

The implementation of Longformer in 4.2.0 is different from 3.0.2, so you might need to modify the conversion script for the new version.

Hi @valhalla,

You are right, but I would be really grateful for help with this. I have pinpointed the places in the code that give different results, but my knowledge of transformers is limited and I don’t know what has changed since version 3.0.2.

I have managed to run RoBERTa with Longformer attention; see the PR: Update conversion script to transformers v4.2.0 by adamwawrzynski · Pull Request #166 · allenai/longformer · GitHub.
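
For anyone else hitting the same ValueError, here is a rough sketch of the idea (see the PR for the complete changes): collapse the 4D extended mask that RobertaModel passes in v4.2.0 back to 2D and derive the is_index_masked / is_index_global_attn / is_global_attn flags inside the wrapper, instead of patching the installed library:

from transformers import LongformerSelfAttention  # v4.2.0

class RobertaLongSelfAttention(LongformerSelfAttention):
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        # The RoBERTa layer passes the extended mask of shape (batch, 1, 1, seq_len)
        # with 0 for tokens to attend to and -10000 for padding; collapse it back
        # to (batch, seq_len) before handing it to LongformerSelfAttention.
        attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
        is_index_masked = attention_mask < 0
        is_index_global_attn = attention_mask > 0
        is_global_attn = is_index_global_attn.flatten().any().item()
        return super().forward(
            hidden_states,
            attention_mask=attention_mask,
            is_index_masked=is_index_masked,
            is_index_global_attn=is_index_global_attn,
            is_global_attn=is_global_attn,
            output_attentions=output_attentions,
        )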