Fine-tuning BERT with sequences longer than 512 tokens

The BERT models I have found in the :hugs: Model Hub handle a maximum input length of 512 tokens. Using sequences longer than 512 seems to require training the models from scratch, which is time consuming and computationally expensive. However, the only element of a pretrained BERT model that limits the input length to 512 is the position embedding matrix. Would it therefore be okay to expand the matrix of pretrained position embeddings so it can handle longer sequences, and thus avoid training from scratch? The expansion can be done by re-using the first 512 embeddings to extend the matrix of position embeddings.

I know BERT’s attention mechanism is quadratic in the sequence length, so longer inputs require more computation; however, this seems manageable for sequences below 2k tokens by reducing the batch size. In fact, I tried this expansion of the position embeddings with max_position_embeddings=1024 and managed to fine-tune a bert-medium model in less than 12GB of GPU memory with batch_size=8, at roughly 4 minutes per epoch (4k records in the training set). The improvement in accuracy compared to using a maximum length of 512 was negligible (~1%); however, the ability to use longer sequences might be useful in applications where truncation is too inconvenient and larger models such as Longformer and BigBird are difficult to run with limited computational resources. I suspect this “hacky” expansion of the position embeddings works because the expanded matrix still carries the positional information BERT needs. In summary, my quick test suggests that the simple expansion of the position embeddings works; however, I would like to hear whether anybody sees potential issues with this approach.

Below is the code I used to expand the position embeddings to length 1024 by simply stacking the first 512 positions twice. Note that the position_ids and token_type_ids also need to be expanded.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

max_length = 1024
model_checkpoint = "prajjwal1/bert-medium"
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

# Let the tokenizer and the model config accept sequences up to the new maximum length.
tokenizer.model_max_length = max_length
model.config.max_position_embeddings = max_length

# The position_ids and token_type_ids buffers are sized for 512 positions, so re-create them.
model.base_model.embeddings.position_ids = torch.arange(max_length).expand((1, -1))
model.base_model.embeddings.token_type_ids = torch.zeros(max_length, dtype=torch.long).expand((1, -1))

# Expand the position embedding matrix from 512 to 1024 rows by stacking the pretrained 512 embeddings twice.
orig_pos_emb = model.base_model.embeddings.position_embeddings.weight
model.base_model.embeddings.position_embeddings.weight = torch.nn.Parameter(torch.cat((orig_pos_emb, orig_pos_emb)))
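For completeness, here is a minimal sketch of how the modified model could then be fine-tuned at the new length. The dataset, the "text"/"label" column names, the output directory, and the number of epochs are placeholders for illustration rather than my exact setup; the batch size matches the one mentioned above.

from transformers import Trainer, TrainingArguments

def tokenize(batch):
    # Pad/truncate to the new 1024-token limit instead of BERT's usual 512.
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=max_length)

# `train_dataset` is assumed to be a datasets.Dataset with "text" and "label" columns.
train_dataset = train_dataset.map(tokenize, batched=True)

training_args = TrainingArguments(
    output_dir="bert-medium-1024",   # illustrative output path
    per_device_train_batch_size=8,   # fit in <12GB of GPU memory at 1024 tokens in my test
    num_train_epochs=3,
)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()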

Why not use a Longformer?

In my experience, Longformer and BigBird require a lot of GPU memory. I tried using them on a 14GB GPU, but I was limited to batch_size=1, which took forever to train and yielded rather poor results.

Interesting. I was just thinking: if a BERT model is passed longer sequences of tokens, but the model was initially pretrained with a limit of 512 tokens, wouldn’t the weights be “confused” by the longer sequences and hence perform badly?

Perhaps, but I assume the “confusion” will not be significant, given that the only element in BERT’s architecture that depends on the sequence length is the position embeddings. My empirical test on an IMDB text classification task suggests that using longer sequences does not degrade performance, but it does not significantly improve it either (see the original post, where I mention that accuracy only increased by ~1% with longer sequences). For other datasets and tasks, however, the results may differ.
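As a quick sanity check (a minimal sketch, assuming the expansion code from the original post has already been run), one can verify that the second half of the expanded matrix is an exact copy of the pretrained 512 embeddings, so every position up to 1023 still maps to a pretrained positional vector:

import torch

pos_emb = model.base_model.embeddings.position_embeddings.weight  # shape: (1024, hidden_size)
# Positions 512..1023 re-use the pretrained vectors for positions 0..511.
assert torch.equal(pos_emb[512:], pos_emb[:512])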


Do you know if the max sequence length of 512 refers to character tokens or word tokens?

Hi @moma1820, I saw you posted a similar question in another thread, so I replied there. See the link below:


thanks man!
