How to use encoded hidden_states as input to a BERT/DistilBERT model


I want to use the first 5 layers of DistilBERT as a frozen encoder (no fine-tuning) and fine-tune only the last layer plus my own model, to save memory. As the code below shows, I remove the first 5 layers, which were already used to generate the hidden states. I now have the hidden states from the encoder. Is it correct to pass the hidden states from the 5th layer directly as inputs_embeds to a BERT/DistilBERT model? Thanks!

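import torch
import torch.nn as nn
from transformers import DistilBertModel

hidden_states = torch.load('')  # the output of the 5th layer of the DistilBERT encoder
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')  # checkpoint name is an assumption
    def forward(self, hidden_state, mask):
        distilbert_output = self.distilbert(inputs_embeds=hidden_state, attention_mask=mask)
        hidden_state = distilbert_output[0]  # last hidden state: (batch, seq_len, dim)
        pooled_output = hidden_state[:, 0]   # representation of the first ([CLS]) token
        x = pooled_output
        # self-defined classification layers go here
        return x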

# Remove the layers that were already applied offline, keeping only the last one(s)
model = Net()
num_kept_layers = 1  # Number of final layers to keep (the slice below keeps, not removes)
encoder_layers = model.distilbert.transformer.layer[-num_kept_layers:]
model.distilbert.transformer.layer = nn.ModuleList(encoder_layers)
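
A quick way to check whether this behaves as intended is to compare against the full model on a small, randomly initialised DistilBERT. The sketch below is only a test harness under assumptions: the tiny config values, the variable names (`full`, `tail`, `fifth`) and the random inputs are all made up for illustration. It takes the 5th-layer output via `output_hidden_states=True`, feeds it through a copy that keeps only the last layer using `inputs_embeds`, and prints the largest difference from the full model's output. Note that BERT's embedding module adds position embeddings and applies LayerNorm to `inputs_embeds`, and some DistilBERT versions in `transformers` appear to do the same, so a non-zero difference would suggest the hidden states are being re-embedded rather than passed straight to the last layer.

```python
import copy

import torch
import torch.nn as nn
from transformers import DistilBertConfig, DistilBertModel

torch.manual_seed(0)

# Tiny, randomly initialised model (hypothetical sizes, no download needed).
config = DistilBertConfig(
    vocab_size=100, dim=32, n_layers=6, n_heads=4,
    hidden_dim=64, max_position_embeddings=32,
)
full = DistilBertModel(config).eval()  # eval() disables dropout

input_ids = torch.randint(0, 100, (2, 8))
mask = torch.ones(2, 8, dtype=torch.long)

with torch.no_grad():
    out = full(input_ids=input_ids, attention_mask=mask,
               output_hidden_states=True)
    reference = out.last_hidden_state
    # hidden_states[0] is the embedding output, so index 5 is the
    # output of the 5th transformer layer.
    fifth = out.hidden_states[5]

    # Copy of the model that keeps only the last transformer layer.
    tail = copy.deepcopy(full)
    tail.transformer.layer = nn.ModuleList(list(tail.transformer.layer[-1:]))

    # The approach from the question: pass the hidden states as inputs_embeds.
    candidate = tail(inputs_embeds=fifth, attention_mask=mask).last_hidden_state

# Close to zero only if inputs_embeds reaches the last layer unchanged.
max_diff = (reference - candidate).abs().max().item()
print("max abs difference vs. full model:", max_diff)
```

If the difference is not close to zero, one option is to skip `self.distilbert(...)` and call the remaining transformer layer on the hidden states directly; the exact call signature of that layer depends on the installed `transformers` version.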