Partially fine-tuning an encoder in an encoder-decoder transformer

Hi. I’m working on an application built on an encoder-decoder network. I’ve frozen the encoder and fine-tuned a GPT-2 decoder. This can be achieved simply by setting GPT2Config.add_cross_attention = True and passing encoder_hidden_states to the forward function of GPT-2. When preparing the model with accelerate, I only pass the decoder to prepare, and I extract the encoder features externally. My training code looks like this:

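from accelerate import Accelerator
from transformers import GPT2Config, GPT2LMHeadModel

accelerator = Accelerator()  # assumed to be created before this point
config = GPT2Config()
config.add_cross_attention = True
model = GPT2LMHeadModel.from_pretrained('gpt2', config=config)
model = model.to(accelerator.device)  # assumed device placement
# optimizer and train_loader are defined elsewhere
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)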

for epoch in range(epochs):
    for i, batch in enumerate(train_loader):

        # assumed: move each tensor in the batch to the training device
        batch = tuple(input_tensor.to(accelerator.device) for input_tensor in batch)
        encoder_input, input_ids, segment_ids = batch

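        # my_encoder is not passed to prepare(); it is frozen and run with @torch.no_grad()
        encoder_outputs = my_encoder(encoder_input)
        # encoder_hidden_states is the argument described above; token_type_ids and labels are assumed
        outputs = model(input_ids=input_ids, token_type_ids=segment_ids,
                        encoder_hidden_states=encoder_outputs, labels=input_ids)
        loss = outputs.loss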

If I want to partially fine-tune my_encoder (let’s say only the last block), must I pass the whole my_encoder to the prepare function? Or should I split it into two parts: the first, which is not trainable (not passed to prepare), and the second, which will be trained (and passed to prepare)? Of course, either option requires modifying GPT2LMHeadModel to include the trained part of the encoder, or adding the trainable part's parameters to the optimizer.
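
To make the first option concrete, here is a minimal sketch in plain PyTorch. TinyEncoder, its blocks attribute, and the nn.Linear decoder are hypothetical stand-ins for my_encoder and the GPT-2 decoder, not the real models. It freezes every encoder parameter except those of the last block and gives the optimizer only the trainable parameters. In an accelerate setup, the encoder would presumably then be passed to prepare together with the decoder and the optimizer, and the encoder forward pass would no longer be wrapped in torch.no_grad; that part is an assumption and is not shown here.

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for my_encoder: a stack of blocks in `blocks`."""

    def __init__(self, dim=16, num_blocks=3):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))

    def forward(self, x):
        for block in self.blocks:
            x = torch.relu(block(x))
        return x


my_encoder = TinyEncoder()
decoder = nn.Linear(16, 16)  # hypothetical stand-in for the GPT-2 decoder

# Freeze the whole encoder, then unfreeze only its last block.
for param in my_encoder.parameters():
    param.requires_grad = False
for param in my_encoder.blocks[-1].parameters():
    param.requires_grad = True

# Give the optimizer only the parameters that will actually be trained.
all_params = list(my_encoder.parameters()) + list(decoder.parameters())
trainable = [p for p in all_params if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)

# One training step. The encoder is not run under torch.no_grad() here,
# so gradients can reach its last block.
x = torch.randn(4, 16)
loss = decoder(my_encoder(x)).pow(2).mean()
loss.backward()
optimizer.step()
```

After loss.backward(), the frozen blocks have no gradients, while the last encoder block and the decoder do.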