Setup:
pip install -r examples/requirements.txt
pip install torch==1.5.1
You can try torch 1.6, but the current code does not work well with it.
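Since the scripts were only verified against torch 1.5.x, a small guard can catch a mismatched install early. This is a minimal sketch; the helper name `is_supported` is mine, not part of the examples:

```python
def is_supported(version: str) -> bool:
    """Return True for the torch 1.5.x series these scripts were tested on."""
    major, minor = version.split(".")[:2]
    return (major, minor) == ("1", "5")
```

In practice you would call it as `is_supported(torch.__version__)` right after importing torch and print a warning when it returns False.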
Fetch data for wmt_en_ro
wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz
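After extraction, `finetune.py` expects one source/target file pair per split; a quick sanity check like the sketch below can confirm the layout before launching a run (the `train/val/test` + `.source/.target` naming is an assumption based on the examples' data format):

```python
from pathlib import Path

SPLITS = ["train", "val", "test"]   # split names assumed from the seq2seq examples
SIDES = ["source", "target"]        # one file per language side

def missing_files(data_dir):
    """List the <split>.<side> files the finetuning script expects but that are absent."""
    root = Path(data_dir)
    return [f"{s}.{t}" for s in SPLITS for t in SIDES if not (root / f"{s}.{t}").exists()]
```

Running `missing_files("wmt_en_ro")` should return an empty list if the tarball unpacked correctly.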
Finetune a 3-layer decoder on wmt_en_ro
export PYTHONPATH="../":"${PYTHONPATH}"
export WANDB_PROJECT=nate
export BS=64
export m=sshleifer/student_marian_en_ro_6_3
export MAX_LEN=128
python finetune.py \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 \
--val_check_interval 0.25 \
--data_dir wmt_en_ro \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--freeze_encoder --freeze_embeds --num_train_epochs=6 \
--train_batch_size=$BS --eval_batch_size=$BS \
--tokenizer_name $m --model_name_or_path $m \
--warmup_steps 500 --sortish_sampler --logger_name wandb \
--gpus 1 --fp16_opt_level=O1 --task translation --label_smoothing 0.1 \
"$@"
This should reach val_avg_bleu > 24 within an hour; one epoch should take less than an hour.
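The command passes `--label_smoothing 0.1`, which mixes the target token's NLL with a uniform penalty over the vocabulary. A pure-Python sketch of this common formulation (the function name and single-token framing are mine, for illustration):

```python
def label_smoothed_nll(log_probs, target, eps=0.1):
    """Label-smoothed NLL for one token: blend the target's negative log-prob
    with the mean negative log-prob over the whole vocabulary."""
    nll = -log_probs[target]                      # standard cross-entropy term
    smooth = -sum(log_probs) / len(log_probs)     # uniform-over-vocab term
    return (1 - eps) * nll + eps * smooth
```

With `eps=0.0` this reduces to plain cross-entropy; `eps=0.1` softens the objective so the model is less confident on any single token.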
cc @nateraw