What I know and don't know about sequence to sequence batching

The old code in examples/seq2seq/finetune.py:

decoder_input_ids = labels[:, 1:]

deletes the first real target token whenever there is no bos at the start of the tensor.
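A toy illustration of the difference (the token lists are my own example, not from the repo):

# With a leading bos, slicing off position 0 only removes the special token;
# without it (the mbart/pegasus case) the first real target word is lost.
labels_with_bos = ["<bos>", "the", "cat", "sat", "<eos>"]
labels_no_bos   = ["the", "cat", "sat", "<eos>"]

labels_with_bos[1:]  # ["the", "cat", "sat", "<eos>"]  -> nothing meaningful lost
labels_no_bos[1:]    # ["cat", "sat", "<eos>"]         -> first target word deleted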

The new code

decoder_input_ids = shift_tokens_right(labels)
  • doesn’t delete tokens.
  • improves the metrics of finetuned mbart and pegasus (which don’t use bos).
  • does not change the metrics of finetuned models that do use bos.

This makes sense. Deleting the first word of every tgt example makes finetuning worse.

What I don’t understand:

bart’s shift_tokens_right wraps the eos token around to position 0 of decoder_input_ids.
This means that models are finetuned with the equivalent of decoder_start_token_id=eos_token_id.
However, when it comes time to evaluate mbart/pegasus/marian, decoder_start_token_id=pad_token_id produces better metrics than decoder_start_token_id=eos_token_id, while for the bart variants decoder_start_token_id=eos_token_id works best.
Additionally, switching to the t5 shift_tokens_right functionality, which puts decoder_start_token_id at position 0 for finetuning, doesn’t improve metrics at all.
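For reference, the bart version of shift_tokens_right looked roughly like this at the time (a sketch from memory, not a verbatim copy; the real function lives in modeling_bart.py and takes pad_token_id so it can find the last non-pad token):

import torch

def shift_tokens_right(input_ids, pad_token_id):
    """Shift input ids one position to the right and wrap the last
    non-pad token (usually eos) around to position 0."""
    prev_output_tokens = input_ids.clone()
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens

The last non-pad label token (normally eos) ends up at position 0, which is why finetuning behaves as if decoder_start_token_id were eos_token_id.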

Does this make sense? What am I missing?

@sshleifer
Isn’t it necessary to add a start token to the decoder sequence (be it eos, bos, pad, or anything else) when training generative models? The labels need to be shifted to calculate the loss; without shifting the labels, the model will only learn to reproduce the last token it has seen.
Is this right?

And I think MBart and PEGASUS don’t have bos and may have been pre-trained with the pad token (like T5) as the first token, which could explain this behavior. I couldn’t find this info in the PEGASUS repo. What was the decoder_start_token_id when pre-training PEGASUS and MBart?
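To make sure I am describing the same thing, here is a toy example of the shift I mean (my own illustration, not library code):

# At step t the decoder predicts labels[t] from decoder_input_ids[:t + 1].
labels = ["the", "cat", "sat", "<eos>"]
decoder_input_ids = ["<start>"] + labels[:-1]   # shift right by one
# -> ["<start>", "the", "cat", "sat"]
# Without the shift, the decoder would see the very token it is asked to
# predict and would only learn to copy its last input token.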

  • You don’t need to add bos through the tokenizer; shift_tokens_right in modeling_bart.py always puts eos at the first position for you.
  • That’s also how it’s done in fairseq, where mbart was trained.
  • Pegasus was probably trained with <pad> at the first position, but seems to finetune well enough with <eos> at the first position.

Does that answer your Q?

Contents of the gist explaining fairseq batching:

During training, fairseq passes mbart dynamically sized batches (up to 128 tokens) in a dict called sample with the following relevant keys:

  • target (our labels): no bos, ends with [2, tgt_lang_code]
  • net_input.src_tokens (our input_ids): ends with [2, 250004]
  • net_input.prev_output_tokens (our decoder_input_ids): starts with 250020, ends with 2. This is the “shift_tokens_right” version of target (see the sketch just below).
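A minimal sketch of that relationship (my own illustration using torch.roll; it ignores padding, which the real collation handles):

import torch

# target ends with [eos=2, tgt_lang_code=250020]; prev_output_tokens is the
# same sequence rotated right by one, so the language code lands at position 0.
target = torch.tensor([[9345, 202, 5, 2, 250020]])
prev_output_tokens = torch.roll(target, shifts=1, dims=1)
# tensor([[250020, 9345, 202, 5, 2]]) -> starts with 250020, ends with 2,
# matching the logs below.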

Logs

Here are the logs from my breakpoint:

ipdb> sample.keys()
dict_keys(['id', 'nsentences', 'ntokens', 'net_input', 'target'])
ipdb> sample['net_input'].keys()
dict_keys(['src_tokens', 'src_lengths', 'prev_output_tokens'])
ipdb> sample['target'][0]
tensor([  9345,    202,     10, 181684,     36,  21635,   8454,  48993,  45587,
            21,  57476,   1283,  98748,    451,    346,   8916,    202,     28,
             9,      7,    451,  11650, 128402,      5,      2, 250020],
       device='cuda:0')
ipdb> sample['net_input']['src_tokens'][0]
tensor([   581,   4738,  30666,    297,     10,  21635,   1363,     98,  28811,
           552,      9,  21473,   1363,     23,     70,     28,      9,  94005,
          8916,      5,      2, 250004], device='cuda:0')
ipdb> sample['net_input']['prev_output_tokens'][0]
tensor([250020,   9345,    202,     10, 181684,     36,  21635,   8454,  48993,
         45587,     21,  57476,   1283,  98748,    451,    346,   8916,    202,
            28,      9,      7,    451,  11650, 128402,      5,      2],
       device='cuda:0')

Command

(first, wget fairseq_wmt_enro.tgz from s3)


 fairseq-train fairseq_wmt_enro  --encoder-normalize-before --decoder-normalize-before  --arch mbart_large \
 --task translation_from_pretrained_bart  --source-lang $SRC --target-lang $TGT --criterion label_smoothed_cross_entropy \
 --label-smoothing 0.2  --dataset-impl mmap --optimizer adam \
 --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--lr-scheduler polynomial_decay --lr 3e-4 --min-lr -1 \
--warmup-updates 2500 --total-num-update 300000 --dropout 0.2 --attention-dropout 0.1 \
--weight-decay 0.0 --max-tokens 128 --update-freq 2 --save-interval 1 --save-interval-updates 5000 \
--keep-interval-updates 3 --no-epoch-checkpoints --seed 222 \
--log-interval 2 --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \
--restore-file $PRETRAIN --langs $langs --layernorm-embedding  \
--ddp-backend no_c10d --save-dir $DEST --memory-efficient-fp16