Training a model with custom attention masks in each layer

Hey team,

I’m trying to train a model with a custom attention mask. I’m using BERT for now, but would like to use other encoder/decoder models later. The attention mask varies with both the depth of the layer and the token’s position. A simple scheme is to use normal dense self-attention for the first k layers and sparse attention for all later layers.
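To make the scheme concrete, here is a minimal sketch of the per-layer masks I have in mind, in plain Python. The function name `build_layer_masks` and the `window` parameter are made up for illustration, and the sliding-window pattern is only one assumed example of "sparse" attention; the actual sparse pattern could be anything.

```python
def build_layer_masks(num_layers, seq_len, k, window):
    """Return one seq_len x seq_len mask per layer (1 = may attend, 0 = masked)."""
    # Dense mask: every token may attend to every other token.
    dense = [[1] * seq_len for _ in range(seq_len)]
    # Assumed sparse pattern (hypothetical): token i attends only to
    # tokens j with |i - j| <= window.
    local = [
        [1 if abs(i - j) <= window else 0 for j in range(seq_len)]
        for i in range(seq_len)
    ]
    # The first k layers get the dense mask, all later layers the sparse one.
    return [dense if layer < k else local for layer in range(num_layers)]


masks = build_layer_masks(num_layers=4, seq_len=5, k=2, window=1)
```

With these settings, `masks[0]` and `masks[1]` are fully dense, while `masks[2]` and `masks[3]` only allow each token to see itself and its immediate neighbours. What I'm missing is how to get each layer of the model to use its own mask.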

I was wondering if there is a simple way of accomplishing this with a pre-trained model. Since this changes the logic of the model's forward pass, do I need to subclass the corresponding model class (e.g. BertForSequenceClassification) and override the relevant functions, or is there an easier way?