Hello,
Sometimes I need to fine-tune the output head of a pre-trained Transformer model. For instance, the multiple-choice head of GPT2DoubleHeadsModel is not pre-trained, so I need to fine-tune it, even though the weights of the main body of the model are fine. I have been told that I should use a fairly high learning rate for the output-head weights and a low learning rate for the main-body weights, but I don't know how to do that.
This is how I currently train pre-trained Transformer models:
def train(model, train_data, optimizer, scheduler, log_interval, pad_index, epoch=None):
    """Run one training epoch over ``train_data`` and periodically log the loss.

    Args:
        model: Called as ``model(input_ids, labels=input_ids)``; element 0 of
            its output is the loss that is back-propagated.
        train_data: Sized iterable of batches. Each batch exposes a ``text``
            attribute holding the token ids of one example.
        optimizer: Optimizer that updates ``model``'s parameters.
        scheduler: LR scheduler. It is only read here (``get_last_lr``) for
            logging; stepping it is left to the caller.
        log_interval: A log line is printed every ``log_interval`` batches.
        pad_index: Currently unused; kept so existing calls keep working.
        epoch: Epoch number shown in the log line. When omitted, falls back to
            a module-level ``epoch`` variable (what this function used to read
            implicitly), or 0 if there is none.
    """
    if epoch is None:
        # Backward compatibility: callers previously relied on the script's
        # global ``epoch`` loop variable being visible inside this function.
        epoch = globals().get("epoch", 0)

    model.train()  # enable training-mode behavior (dropout etc.)
    total_loss = 0.0
    batches_since_log = 0
    for batch_index, batch in enumerate(train_data):
        optimizer.zero_grad()
        # The extra list adds a batch dimension of size 1.
        input_ids = torch.tensor([list(batch.text)], dtype=torch.long)
        loss = model(input_ids, labels=input_ids)[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        # .item() yields a plain float; accumulating the tensor itself would
        # keep every batch's autograd graph alive until the next reset.
        total_loss += loss.item()
        batches_since_log += 1

        if batch_index % log_interval == 0 and batch_index > 0:
            # Divide by the batches actually summed: the first window spans
            # indices 0..log_interval, i.e. log_interval + 1 batches.
            cur_loss = total_loss / batches_since_log
            try:
                ppl = math.exp(cur_loss)
            except OverflowError:
                ppl = float("inf")
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.9f} | loss {:5.4f} | ppl {:8.4f}'.format(
                epoch, batch_index, len(train_data), scheduler.get_last_lr()[0], cur_loss, ppl))
            total_loss = 0.0
            batches_since_log = 0
import copy
import warnings

# Discriminative learning rates: the multiple-choice head starts from random
# initialization, so it gets a much larger step size than the pre-trained body.
BODY_LR = 0.000005200
HEAD_LR = 0.0001  # starting point only; tune on validation data
# NOTE(review): assumes the model is transformers' GPT2DoubleHeadsModel, whose
# multiple-choice head parameters are named "multiple_choice_head.*". Confirm
# with [name for name, _ in model.named_parameters()].
HEAD_PREFIXES = ("multiple_choice_head.",)

head_params = []
body_params = []
for param_name, param in model.named_parameters():
    if param_name.startswith(HEAD_PREFIXES):
        head_params.append(param)
    else:
        body_params.append(param)
if not head_params:
    warnings.warn(
        "No parameters matched HEAD_PREFIXES; every parameter will use BODY_LR."
    )
# Each dict is an optimizer "param group" with its own learning rate.
# Empty groups are dropped so the optimizer never receives a group without params.
# Group 0 is the body, so train()'s `lr` log column shows BODY_LR.
# (To train only the head, set requires_grad=False on body_params instead.)
param_groups = [
    group
    for group in (
        {"params": body_params, "lr": BODY_LR},
        {"params": head_params, "lr": HEAD_LR},
    )
    if group["params"]
]
optimizer = AdamW(param_groups, lr=BODY_LR, correct_bias=True)
scheduler = get_constant_schedule(optimizer=optimizer, last_epoch=-1)

# NOTE(review): train() minimizes element 0 of model(input_ids, labels=input_ids).
# For GPT2DoubleHeadsModel that appears to be the language-modeling loss, which
# sends no gradient into the multiple-choice head. To actually fine-tune the head,
# the training step presumably also needs mc_token_ids / mc_labels and the
# multiple-choice loss. TODO confirm against the model's forward() signature.
for epoch in range(1, epochs + 1):
    gc.collect()
    epoch_start_time = time.time()
    train(model, train_data,
          optimizer, scheduler,
          log_int, pad_index)
    # NOTE(review): val_loss is not recomputed inside this loop, so every epoch
    # reports and compares the same value. An evaluation step
    # (e.g. val_loss = evaluate(model, ...)) seems to be missing here. TODO confirm.
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.4f} | '
          'valid ppl {:8.4f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Snapshot the weights. A plain `best_model = model` only aliases the
        # live model, which keeps training, so it would always end up as the
        # final epoch's model rather than the best one.
        best_model = copy.deepcopy(model)
    scheduler.step()
How can I give the output head a higher learning rate than the main body, or is there a better way to fine-tune the output head?
Thank you,