I read the FAIR self-training paper and it is very relevant, but I am struggling to understand the specifics of the approach.
Things I understood:
- “use beam search decoding (beam size 5) to create the pseudo targets and to report BLEU on test set.” Beam search outperforms sampling for generating the pseudo targets.
- They train with a dropout rate of 0.3, which helps.
- They don’t do any cleaning or checking of the pseudo-labels against the ground-truth targets.
- All experiments are run on 8 GPUs with an effective batch size of 33K tokens.
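As a sanity check on that batch size: with token-based batching, 33K tokens spread over 8 GPUs is roughly 4K tokens per GPU per update. The sketch below assumes a fairseq-style setup in which the effective batch is the per-GPU token budget times the number of GPUs times the gradient-accumulation steps (fairseq's `--max-tokens` and `--update-freq`). The paper's exact flags aren't given here, so the 4096-token budget is an illustrative assumption, and the helper names are hypothetical.

```python
def effective_batch_tokens(max_tokens_per_gpu: int, num_gpus: int, update_freq: int = 1) -> int:
    """Upper bound on tokens per optimizer step: per-GPU budget x GPUs x accumulation steps."""
    return max_tokens_per_gpu * num_gpus * update_freq


def per_gpu_budget(target_tokens: int, num_gpus: int, update_freq: int = 1) -> int:
    """Per-GPU token budget needed to reach a target effective batch (rounded up)."""
    return -(-target_tokens // (num_gpus * update_freq))


# Assumed configuration: 4096 tokens per GPU, 8 GPUs, no gradient accumulation.
print(effective_batch_tokens(4096, 8))  # 32768, i.e. roughly 33K tokens
print(per_gpu_budget(33_000, 8))        # per-GPU budget for exactly 33K on 8 GPUs
print(per_gpu_budget(33_000, 2, 4))     # same effective batch on 2 GPUs with update_freq=4
```

The last line shows why the update frequency matters: with fewer GPUs you can reproduce the same effective batch by accumulating gradients over several steps.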
Things I did not understand:
In Figure 1 (page 3), what is the difference between the light-shaded (pseudo-training) and dark-shaded (fine-tune) bars?
Section 3.2 tries to explain it (excerpted below, reformatted as a list):
In Figure 1, we use green bars to show the result of applying self-training for three iterations. We include both
- (1) pseudo-training (PT): the first step of self-training where we train a new model (from scratch) using only the pseudo parallel data generated by the current model, and
- (2) fine-tuning (FT): the fine-tuned system using real parallel data based on the pretrained model from the PT step.
- Note that in the fine-tuning step the system is re-initialized from scratch.
- Surprisingly, we find that the pseudo-training step at the first iteration is able to improve BLEU even if the model is only trained on its own predictions, and fine-tuning further boosts the performance. The test BLEU keeps improving over the first three iterations, until convergence to outperform the initial baseline by 3 BLEU points.
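To make the light/dark distinction concrete, here is a toy sketch of one reading of the excerpt: in each iteration, PT trains a fresh model on pseudo pairs only (light bar), then FT takes a copy of that PT model and keeps training it on the real parallel data (dark bar). Everything here is hypothetical: `ToyTranslator` is a word-by-word counting "model" standing in for a Transformer, greedy lookup stands in for beam search, word accuracy stands in for BLEU, and using the FT model to label the next iteration is an assumption rather than something the excerpt states.

```python
import copy
from collections import Counter, defaultdict


class ToyTranslator:
    """Hypothetical stand-in for a Transformer: a word-by-word table learned by counting."""

    def __init__(self):
        self.counts = defaultdict(Counter)  # source word -> Counter of aligned target words

    def train(self, pairs):
        # Adds to existing counts, so calling train() again acts like continued training.
        for src, tgt in pairs:
            for s, t in zip(src.split(), tgt.split()):
                self.counts[s][t] += 1
        return self

    def _translate_word(self, word):
        c = self.counts.get(word)
        return c.most_common(1)[0][0] if c else word  # unknown words are copied through

    def generate(self, sentences):
        # Greedy per-word lookup stands in for beam search.
        return [" ".join(self._translate_word(w) for w in s.split()) for s in sentences]

    def evaluate(self, pairs):
        # Position-wise word accuracy stands in for BLEU.
        hyps = self.generate([src for src, _ in pairs])
        hits = total = 0
        for hyp, (_, ref) in zip(hyps, pairs):
            hits += sum(h == r for h, r in zip(hyp.split(), ref.split()))
            total += len(ref.split())
        return hits / total


def self_train(parallel, unlabeled, valid, iterations=3):
    """Returns one (PT score, FT score) pair per iteration: one light and one dark bar."""
    teacher = ToyTranslator().train(parallel)  # baseline trained on real data only
    history = []
    for _ in range(iterations):
        pseudo = list(zip(unlabeled, teacher.generate(unlabeled)))
        pt_model = ToyTranslator().train(pseudo)            # PT: fresh model, pseudo data only
        ft_model = copy.deepcopy(pt_model).train(parallel)  # FT: start from PT weights, real data
        history.append((pt_model.evaluate(valid), ft_model.evaluate(valid)))
        teacher = ft_model  # assumption: the FT model labels the next iteration
    return history


parallel = [("the cat", "die katze"), ("a dog", "ein hund")]
unlabeled = ["the dog", "a cat"]
for i, (pt, ft) in enumerate(self_train(parallel, unlabeled, valid=parallel), start=1):
    print(f"iteration {i}: PT={pt:.2f} FT={ft:.2f}")
```

The `deepcopy` is what keeps the two bars separate: the PT score is measured on the model before it has seen any real data, and the FT score on a copy that has.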
So when they write “the fine-tuned system using real parallel data based on the pretrained model from the PT step”, I guess they mean that, at each iteration, the fine-tuned system uses real parallel data and self-training data, based on the model trained in the previous step.
Pseudocode of my understanding:
def fair_self_training_procedure(parallel_data, unlabeled_data, valid_data, mode='pseudo training'):
    # parallel_data: list of ~100K (English, German) sentence pairs
    # unlabeled_data: list of ~3M English sentences
    # randomly_initialize, train, inject_noise and model.evaluate/generate are
    # hypothetical placeholders for the real training code, not library calls.

    # baseline
    model = randomly_initialize('transformer')
    model = train(model, parallel_data, dropout=0.3)
    baseline_performance = model.evaluate(valid_data)  # 15.6
    pseudo_dataset = list(zip(unlabeled_data, model.generate(unlabeled_data, num_beams=5)))
    if mode == 'fine-tune':  # HELP
        pseudo_dataset = pseudo_dataset + parallel_data

    scores = []
    for iteration in range(3):  # iterations 1, 2, 3
        # Even in the fine-tuning step the system is re-initialized from scratch.
        model = randomly_initialize('transformer')
        noisy_dataset = inject_noise(pseudo_dataset)  # returns a noised copy
        model = train(model, noisy_dataset, dropout=0.3)
        scores.append(model.evaluate(valid_data))
        pseudo_dataset = list(zip(unlabeled_data, model.generate(unlabeled_data, num_beams=5)))
        if mode == 'fine-tune':  # HELP
            pseudo_dataset = pseudo_dataset + parallel_data
    return baseline_performance, scores
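The pseudocode leaves `inject_noise` undefined. Below is a minimal sketch of one common choice, assuming the noise is the usual word dropping, word blanking, and local word shuffling applied to the source side only, with targets left untouched. The probabilities, the shuffle window, and the `<blank>` token are illustrative defaults, not values taken from the paper, and both function names are hypothetical.

```python
import random


def noise_sentence(tokens, rng, p_drop=0.1, p_blank=0.1, shuffle_k=3, blank="<blank>"):
    # 1) Word dropping: remove each token with probability p_drop (keep at least one token).
    kept = [t for t in tokens if rng.random() >= p_drop] or tokens[:1]
    # 2) Word blanking: replace each token with a placeholder with probability p_blank.
    blanked = [blank if rng.random() < p_blank else t for t in kept]
    # 3) Local shuffle: sort by position + U(0, shuffle_k), so every token ends up
    #    fewer than shuffle_k positions away from where it started.
    keys = [i + rng.uniform(0, shuffle_k) for i in range(len(blanked))]
    return [t for _, t in sorted(zip(keys, blanked), key=lambda kt: kt[0])]


def inject_noise(pairs, seed=None, **kwargs):
    """Return a new list of (noisy_source, target) pairs; the input list is not modified."""
    rng = random.Random(seed)
    return [(" ".join(noise_sentence(src.split(), rng, **kwargs)), tgt) for src, tgt in pairs]


print(inject_noise([("the cat sat on the mat", "die Katze saß auf der Matte")], seed=0))
```

Returning a fresh list instead of mutating `pseudo_dataset` matters in the loop above: otherwise noise from earlier iterations would pile up on the same sentences.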