How do I fine-tune the output head of a pre-trained Transformer model?

Sometimes I need to fine-tune the output head of a pre-trained Transformer model. For instance, the multiple-choice head of GPT2DoubleHeadsModel is not pre-trained, so it has to be trained from scratch, while the weights of the main body are already fine. I have been told to use a fairly high learning rate for the weights of the output head and a low one for the weights of the main body, but I don’t know how to do that.

This is how I currently train the pre-trained Transformer models:

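import math
import time

import torch
from transformers import AdamW, get_constant_schedule

def train(model, train_data, optimizer, scheduler, log_interval, pad_index):
    # put the model in training mode (enables dropout)
    model.train()
    # running sum of the loss since the last log line
    total_loss = 0.0

    for batch_index, batch in enumerate(train_data):
        input_ids = [instance for instance in batch.text]
        # wrap in a list to add a batch dimension of size 1
        input_ids = torch.tensor([input_ids], dtype=torch.long)

        optimizer.zero_grad()
        loss = model(input_ids, labels=input_ids)[0]
        loss.backward()
        # clip gradients after backward() and before the optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        scheduler.step()
        # accumulate a Python float so the autograd graph is not kept alive
        total_loss += loss.item()

        if batch_index % log_interval == 0 and batch_index > 0:
            cur_loss = total_loss / log_interval
            # `epoch` is read from the enclosing (global) training loop
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.9f} | loss {:5.4f} | ppl {:8.4f}'.format(
                    epoch, batch_index, len(train_data), scheduler.get_last_lr()[0], cur_loss, math.exp(cur_loss)))
            total_loss = 0.0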

optimizer = AdamW(model.parameters(), lr=5.2e-6, correct_bias=True)
scheduler = get_constant_schedule(optimizer=optimizer, last_epoch=-1)

best_val_loss = float('inf')

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model, train_data,
          optimizer, scheduler,
          log_int, pad_index)
    # val_loss is assumed to come from an evaluation pass that is not shown here
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.4f} | '
          'valid ppl {:8.4f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # note: this keeps a reference to the live model, not a frozen copy
        best_model = model

What is a better way to fine-tune the output head, so that it gets a higher learning rate than the main body?

Thank you.
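
For reference, the "high learning rate for the head, low for the body" advice is usually expressed with optimizer parameter groups. The following is only a minimal sketch under some assumptions: `TinyDoubleHeads` is a hypothetical stand-in so the snippet runs without downloading GPT-2, the `multiple_choice_head` prefix is assumed to match the parameter names of the real model (worth checking with `model.named_parameters()`), and the head learning rate of `1e-3` is an illustrative value, not a recommendation. It uses `torch.optim.AdamW`; the same list of groups should also be accepted by the `AdamW` used in the post, since optimizers take either parameters or a list of dicts.

```python
import torch
from torch import nn


class TinyDoubleHeads(nn.Module):
    """Hypothetical stand-in for GPT2DoubleHeadsModel: a body plus a fresh head."""

    def __init__(self, hidden=8, vocab=16):
        super().__init__()
        # attribute names are assumed to mirror the real model; verify for your version
        self.transformer = nn.Sequential(
            nn.Embedding(vocab, hidden), nn.Linear(hidden, hidden)
        )
        self.multiple_choice_head = nn.Linear(hidden, 1)


def build_param_groups(model, head_prefix="multiple_choice_head",
                       body_lr=5.2e-6, head_lr=1e-3):
    """Split parameters by name into a low-lr body group and a high-lr head group."""
    head, body = [], []
    for name, param in model.named_parameters():
        if name.startswith(head_prefix):
            head.append(param)
        else:
            body.append(param)
    # each dict is one parameter group with its own learning rate
    return [
        {"params": body, "lr": body_lr},
        {"params": head, "lr": head_lr},
    ]


model = TinyDoubleHeads()
groups = build_param_groups(model)
optimizer = torch.optim.AdamW(groups)

# show how many parameter tensors ended up in each group and their learning rates
for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"])
```

With two groups, a scheduler such as `get_constant_schedule` keeps one learning rate per group, so `scheduler.get_last_lr()` returns a list with the body rate first and the head rate second, in the order the groups were passed.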