Loading Lower Layers of Model

I have trained an ElectraForPreTraining model with 10 encoder layers and saved the checkpoint. I now want to initialize an ElectraForSequenceClassification from this model. However, I also want to be able to initialize ElectraForSequenceClassification models that have only, say, 5 or 3 encoder layers, initialized with the bottom 5 or bottom 3 layers of my ElectraForPreTraining checkpoint. Is there any way to do this with built-in library methods? If not, I figure I would have to isolate the bottom n layers and add a head myself, in which case any help on how to isolate the bottom n layers into a new model would be extremely helpful.
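The "isolate the bottom n layers" plan can be sketched with plain PyTorch module slicing. This assumes the Hugging Face Electra classes keep their encoder blocks in `model.electra.encoder.layer` (an `nn.ModuleList`). `keep_bottom_layers` is a hypothetical helper, and the tiny random config only stands in for the real 10-layer checkpoint so the sketch runs offline. The check at the end compares the truncated model's output with the full model's hidden state after layer 3.

```python
import copy

import torch
from transformers import ElectraConfig, ElectraForSequenceClassification


def keep_bottom_layers(model, n):
    """Hypothetical helper: return a copy of `model` keeping only its bottom `n` encoder layers."""
    small = copy.deepcopy(model)
    # Slicing an nn.ModuleList returns a new nn.ModuleList holding the same layer modules.
    small.electra.encoder.layer = small.electra.encoder.layer[:n]
    # Keep the config in sync so save_pretrained / from_pretrained round-trip correctly.
    small.config.num_hidden_layers = n
    return small


# Tiny random config (assumed values) so the sketch runs offline; in practice `full` would
# come from ElectraForSequenceClassification.from_pretrained(<your 10-layer checkpoint dir>).
config = ElectraConfig(vocab_size=100, embedding_size=16, hidden_size=32,
                       num_hidden_layers=10, num_attention_heads=2,
                       intermediate_size=64, max_position_embeddings=64)
full = ElectraForSequenceClassification(config).eval()
small = keep_bottom_layers(full, 3).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 8))
with torch.no_grad():
    # hidden_states[0] is the embedding output, hidden_states[3] the output of layer 3.
    full_hidden = full.electra(input_ids, output_hidden_states=True).hidden_states
    small_out = small.electra(input_ids).last_hidden_state

print(len(small.electra.encoder.layer), torch.allclose(full_hidden[3], small_out, atol=1e-6))
```

The sequence-classification head on top of the truncated encoder is left as it was, so it is trained from scratch as usual when fine-tuning.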

You could load the full model and a smaller model with fewer layers, and then copy the layers from the full model into the smaller model. This snippet might help:

```python
import torch.nn as nn
from transformers import ElectraForSequenceClassification, ElectraConfig

# load the pre-trained model (replace the name with the path of your own checkpoint)
model = ElectraForSequenceClassification.from_pretrained("google/electra-small-discriminator")

# create a smaller model with only 3 encoder layers
config = ElectraConfig.from_pretrained("google/electra-small-discriminator", num_hidden_layers=3)
student = ElectraForSequenceClassification(config)

# this function takes the layers of the full model, the layers of the smaller model
# and the indices of the layers to copy, and copies the source layers into the destination layers
def copy_layers(src_layers, dest_layers, layers_to_copy):
    layers_to_copy = nn.ModuleList([src_layers[i] for i in layers_to_copy])
    assert len(dest_layers) == len(layers_to_copy), f"{len(dest_layers)} != {len(layers_to_copy)}"
    dest_layers.load_state_dict(layers_to_copy.state_dict())

# to copy the bottom three layers, give the indices of the first 3 layers
# (electra-small has 12 layers, so [9, 10, 11] would copy the last three instead)
layers_to_copy = [0, 1, 2]

copy_layers(model.electra.encoder.layer, student.electra.encoder.layer, layers_to_copy)

# the student was randomly initialized, so copy the embeddings as well
student.electra.embeddings.load_state_dict(model.electra.embeddings.state_dict())
# models whose embedding_size differs from hidden_size also have a projection layer
if hasattr(model.electra, "embeddings_project"):
    student.electra.embeddings_project.load_state_dict(model.electra.embeddings_project.state_dict())

# save the smaller model
student.save_pretrained("student_model")

# now you can load the smaller model using
student = ElectraForSequenceClassification.from_pretrained("student_model")
```
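If you only ever need the bottom n layers, `from_pretrained` may already do this for you: extra keyword arguments override values in the saved config, and checkpoint weights with no matching parameter (here layers 3 to 9 and the discriminator head) are skipped with a warning, while the classification head is newly initialized. Below is a minimal sketch of that route, with a tiny random `ElectraForPreTraining` (assumed config values) standing in for your checkpoint; `same_weights` is a hypothetical helper used only for the check. It only works for the bottom layers, because checkpoint keys are matched by layer index; for any other selection, copying layers as above is still needed. On your real checkpoint, read the loading warning to confirm that only the expected keys are reported as unused.

```python
import tempfile

import torch
from transformers import ElectraConfig, ElectraForPreTraining, ElectraForSequenceClassification

# Tiny random 10-layer discriminator standing in for the real ElectraForPreTraining checkpoint.
config = ElectraConfig(vocab_size=100, embedding_size=16, hidden_size=32,
                       num_hidden_layers=10, num_attention_heads=2,
                       intermediate_size=64, max_position_embeddings=64)
pretrained = ElectraForPreTraining(config)

ckpt_dir = tempfile.mkdtemp()  # throwaway directory for this sketch
pretrained.save_pretrained(ckpt_dir)

# num_hidden_layers=3 overrides the saved config, so only encoder.layer.0-2 find matching
# parameters; the remaining layers and the discriminator head are ignored.
student = ElectraForSequenceClassification.from_pretrained(ckpt_dir, num_hidden_layers=3)


def same_weights(a, b):
    """Hypothetical helper: True if two modules have identical state dicts."""
    sa, sb = a.state_dict(), b.state_dict()
    return sa.keys() == sb.keys() and all(torch.equal(sa[k], sb[k]) for k in sa)


print(student.config.num_hidden_layers)
print(all(same_weights(pretrained.electra.encoder.layer[i], student.electra.encoder.layer[i])
          for i in range(3)))
print(same_weights(pretrained.electra.embeddings, student.electra.embeddings))
```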