The BERT models I have found in the Model Hub handle a maximum input length of 512 tokens. Using longer sequences seems to require training a model from scratch, which is time-consuming and computationally expensive. However, the only thing in a pretrained BERT model that actually limits inputs to 512 tokens is the size of the position embedding matrix. Would it therefore be okay to expand the matrix of pretrained position embeddings to handle longer sequences and avoid training from scratch? The expansion can be done by simply re-using the first 512 embeddings to fill the additional positions.
I know BERT’s attention mechanism scales quadratically with sequence length, so longer inputs require more computation; however, this seems manageable for sequences below 2k tokens by reducing the batch size. In fact, I tried this expansion of the position embeddings with max_position_embeddings=1024 and managed to fine-tune a bert-medium model using less than 12GB of GPU memory with batch_size=8, at a rate of 4 minutes per epoch (4k records in the training set). The improvement in accuracy compared to using a maximum length of 512 was negligible (~1%); however, being able to use longer sequences might be useful in applications where truncation is too inconvenient and larger models such as Longformer and BigBird are too demanding in terms of computational resources. I suspect this “hacky” expansion of the position embeddings works because the expanded matrix still carries the positional information BERT needs. In summary, my quick test suggests the simple expansion of the position embeddings works; I am asking to see whether anybody sees potential issues with this approach.
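For anyone who wants to reproduce the fine-tuning side of this quick test, something along the following lines should work. This is a minimal sketch rather than my exact script: the dummy dataset, number of epochs, and output directory are placeholders, and it assumes model, tokenizer, and max_length have already been patched as in the expansion code further down.

from datasets import Dataset
from transformers import Trainer, TrainingArguments

# Tiny dummy dataset just to make the sketch runnable; my real run used ~4k records.
train_dataset = Dataset.from_dict({"text": ["some long document ..."] * 16, "label": [0, 1] * 8})

def tokenize(batch):
    # Pad/truncate to the expanded maximum length (1024 tokens)
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=max_length)

train_dataset = train_dataset.map(tokenize, batched=True)

training_args = TrainingArguments(
    output_dir="bert-medium-1024",   # placeholder
    per_device_train_batch_size=8,   # this batch size fit in <12GB for bert-medium at 1024 tokens
    num_train_epochs=1,              # placeholder
)

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()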
Below is the code I used to expand the position embeddings to length 1024 by simply stacking the first 512 positions twice. Note that the position_ids and token_type_ids buffers also need to be expanded.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

max_length = 1024
model_checkpoint = "prajjwal1/bert-medium"
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

# Let the tokenizer and the model config know about the new maximum length
tokenizer.model_max_length = max_length
model.config.max_position_embeddings = max_length

# Expand the position_ids and token_type_ids buffers to cover the new length
# (token_type_ids must be an integer tensor to be usable as embedding indices)
model.base_model.embeddings.position_ids = torch.arange(max_length).expand((1, -1))
model.base_model.embeddings.token_type_ids = torch.zeros(max_length, dtype=torch.long).expand((1, -1))

# Stack the pretrained 512 position embeddings twice to build a 1024-row matrix
orig_pos_emb = model.base_model.embeddings.position_embeddings.weight
model.base_model.embeddings.position_embeddings.weight = torch.nn.Parameter(torch.cat((orig_pos_emb, orig_pos_emb)))
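As a quick sanity check that the patched model really accepts inputs longer than 512 tokens, a forward pass like the following can be used (the input text here is just a placeholder):

# Feed a sequence well over 512 tokens through the patched model
long_text = "hello world " * 700  # placeholder text that tokenizes to more than 1024 tokens

inputs = tokenizer(long_text, truncation=True, max_length=max_length, return_tensors="pt")
print(inputs["input_ids"].shape)  # e.g. torch.Size([1, 1024])

model.eval()
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)  # (1, num_labels) -- the forward pass no longer hits the 512 limit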