Good command to test examples/seq2seq refactors

Setup:

pip install -r examples/requirements.txt
pip install torch==1.5.1

You can also test torch 1.6, but the current code does not work well with it.
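
If you want to fail fast on an incompatible install, a quick guard like this works (a minimal sketch; the version cutoff just mirrors the note above):

# Version guard: this recipe is only verified on torch 1.5.1.
import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse("1.6.0"):
    raise RuntimeError(
        f"torch {torch.__version__} detected; this recipe is only "
        "known-good on torch 1.5.1"
    )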

Fetch the data for wmt_en_ro:

wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz
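
Before training, it's worth sanity-checking the extracted files. This check assumes the layout the seq2seq scripts expect (plain-text {train,val,test}.{source,target} files, one sentence per line) inside wmt_en_ro/:

# Check that source/target files exist and are line-aligned.
from pathlib import Path

data_dir = Path("wmt_en_ro")
for split in ("train", "val", "test"):
    src = data_dir / f"{split}.source"
    tgt = data_dir / f"{split}.target"
    n_src = sum(1 for _ in src.open(encoding="utf-8"))
    n_tgt = sum(1 for _ in tgt.open(encoding="utf-8"))
    assert n_src == n_tgt, f"{split}: {n_src} source vs {n_tgt} target lines"
    print(f"{split}: {n_src} sentence pairs")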

Finetune a 3-layer decoder on wmt_en_ro:

export PYTHONPATH="../":"${PYTHONPATH}"
export WANDB_PROJECT=nate
export BS=64
export m=sshleifer/student_marian_en_ro_6_3
export MAX_LEN=128
python finetune.py \
  --learning_rate=3e-4 \
  --do_train \
  --do_predict \
  --fp16 \
  --val_check_interval 0.25 \
  --data_dir wmt_en_ro \
  --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
  --freeze_encoder --freeze_embeds --num_train_epochs=6 \
  --train_batch_size=$BS --eval_batch_size=$BS \
  --tokenizer_name $m --model_name_or_path $m \
  --warmup_steps 500 --sortish_sampler --logger_name wandb \
  --gpus 1 --fp16_opt_level=O1 --task translation --label_smoothing 0.1 \
  "$@"

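For reference, --freeze_encoder and --freeze_embeds just turn off gradients so the optimizer only updates the 3-layer decoder. Roughly (a sketch, not the exact helper in the repo; the attribute names in the comments are illustrative):

# Disable gradients for a module so it is excluded from training.
def freeze_params(module):
    for p in module.parameters():
        p.requires_grad = False

# e.g. freeze_params(model.model.encoder) and
#      freeze_params(model.model.shared)  # shared embeddings
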
This should reach val_avg_bleu > 24 within an hour; one epoch should take under an hour.
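
--label_smoothing 0.1 swaps plain cross entropy for a label-smoothed NLL loss. A self-contained sketch of the standard fairseq-style computation (not necessarily the exact code in the repo):

import torch

def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """lprobs: (N, vocab) log-probabilities; target: (N,) gold token ids."""
    target = target.unsqueeze(-1)                 # (N, 1)
    pad_mask = target.eq(ignore_index)
    safe_target = target.clamp(min=0)             # gather needs valid indices
    nll_loss = -lprobs.gather(dim=-1, index=safe_target).masked_fill(pad_mask, 0.0)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True).masked_fill(pad_mask, 0.0)
    eps_i = epsilon / lprobs.size(-1)
    # interpolate between the true-label loss and the uniform prior
    return (1.0 - epsilon) * nll_loss.sum() + eps_i * smooth_loss.sum()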

cc @nateraw
