Issues running seq2seq distillation

Hello,

I didn’t see these errors when I ran seq2seq distillation last year, but the script below, run from transformers/examples/research_projects/seq2seq-distillation, gives me a couple of issues.

python distillation.py \
  --teacher google/t5-large-ssm-nq --data_dir $NQOPEN_DIR \
  --tokenizer_name t5-large \
  --student_decoder_layers 6 --student_encoder_layers 6 \
  --freeze_encoder --freeze_embeds \
  --learning_rate=3e-4 \
  --do_train \
  --gpus 4 \
  --do_predict \
  --fp16 --fp16_opt_level=O1 \
  --val_check_interval 0.1 --n_val 500 --eval_beams 2 --length_penalty=0.5 \
  --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
  --model_name_or_path IGNORED \
  --alpha_hid=3. \
  --train_batch_size=2 --eval_batch_size=2 --gradient_accumulation_steps=2 \
  --sortish_sampler \
  --num_train_epochs=6 \
  --warmup_steps 500 \
  --output_dir distilled_t5_sft \
  --logger_name wandb \
  "$@"

Issues:

  1. I get the following warning at the beginning:
Epoch 0:   0%|          | 2/12396 [00:00<1:20:46,  2.56it/s, loss=nan, v_num=xyme]/home/sumithrab/miniconda3/envs/distill/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:131: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
  2. Wandb does not show learning curves. I get the following warning:
Epoch 0:   1%|          | 99/12396 [00:27<56:02,  3.66it/s, loss=5.55e+04, v_num=xyme]wandb: WARNING Step must only increase in log calls.  Step 98 < 100; dropping {'loss': 56946.98828125, 'ce_loss': 24.933889389038086, 'mlm_loss': 9.203145980834961, 'hid_loss_enc': 951.351318359375, 'hid_loss_dec': 18023.71484375, 'tpb': 42, 'bs': 2, 'src_pad_tok': 2, 'src_pad_frac': 0.05882352963089943}.

Any ideas you may have would be very helpful.

Hi @sbhaktha

  1. You can safely ignore this warning

  2. See if wandb is authorized; you should be able to see logs after the first validation run
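For context on the first warning, it is about the call order inside the training loop: PyTorch expects `optimizer.step()` before `lr_scheduler.step()` on each iteration. A minimal sketch of that order (the tiny model, loss, and `StepLR` schedule here are just placeholders, not the distillation script's actual setup):

```python
import torch

# Placeholder model/optimizer/scheduler purely to illustrate call ordering.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

for step in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()   # optimizer first...
    scheduler.step()   # ...then the scheduler, as the warning asks
```

If the scheduler steps first, PyTorch skips the first value of the learning-rate schedule, which is what the warning is flagging; frameworks like PyTorch Lightning manage this ordering internally, which is why it is generally safe to ignore here.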

Hi @valhalla,

wandb is authorized with the API key. I only see the following curves though, even after a couple of validation steps have run:

What does the warning message "wandb: WARNING Step must only increase in log calls. Step 98 < 100; dropping {'loss': 56946.98828125, 'ce_loss': 24.933889389038086, 'mlm_loss': 9.203145980834961, 'hid_loss_enc': 951.351318359375, 'hid_loss_dec': 18023.71484375, 'tpb': 42, 'bs': 2, 'src_pad_tok': 2, 'src_pad_frac': 0.05882352963089943}." mean?
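Roughly, wandb requires the `step` value in successive `wandb.log` calls to increase, and it silently drops any metrics logged with a step smaller than one it has already seen. A toy sketch of that rule (`ToyLogger` is a hypothetical stand-in for illustration, not wandb's implementation):

```python
# Toy illustration of wandb's monotonic-step rule: metrics arriving with a
# step lower than the last-seen step are discarded rather than logged.
class ToyLogger:
    def __init__(self):
        self.last_step = -1
        self.history = []

    def log(self, metrics, step):
        if step < self.last_step:
            print(f"WARNING Step must only increase. {step} < {self.last_step}; dropping {metrics}")
            return False
        self.last_step = step
        self.history.append((step, metrics))
        return True

logger = ToyLogger()
logger.log({"loss": 1.0}, step=100)  # accepted
logger.log({"loss": 0.9}, step=98)   # dropped: 98 < 100
```

This kind of out-of-order logging can happen when two code paths (e.g. a training-loss logger and a validation logger) each report with their own step counter; the dropped entries are one plausible reason curves fail to appear.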

Thanks,
Sumithra

I am also intrigued to find so many python processes resulting from the above command:

Also, memory usage shows up as ~30 GB against each process. My teacher model is t5-large with 770 million parameters, which should be about 3 GB of parameters assuming 32-bit, and in this case fp16 is turned on. The student model is down to 6 layers from the teacher’s 24 layers, so roughly 1/3 the size. Batch size is 2. I am not sure what is accounting for the ~30 GB. Following is the GPU usage:

GPU memory usage is pretty low.
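The parameter arithmetic above can be sketched as a quick back-of-the-envelope check (`param_gib` is a hypothetical helper, and the counts are approximate; optimizer state, activations, and per-process copies under multi-GPU data parallelism all add memory on top of this):

```python
# Rough parameter-memory estimate for the model sizes discussed above.
def param_gib(n_params, bytes_per_param):
    # GiB occupied by the raw parameter tensors alone.
    return n_params * bytes_per_param / 2**30

teacher = 770e6        # t5-large, ~770M parameters
student = teacher / 3  # student kept 6 of 24 layers, roughly 1/3 the size

print(f"teacher fp32: {param_gib(teacher, 4):.1f} GiB")  # ~2.9 GiB
print(f"teacher fp16: {param_gib(teacher, 2):.1f} GiB")
print(f"student fp16: {param_gib(student, 2):.1f} GiB")
```

Even doubling for teacher plus student and adding Adam's fp32 optimizer state, parameters alone stay well under 30 GB, so the gap presumably comes from elsewhere (e.g. each spawned process holding its own full copy of everything, plus dataloader and framework overhead).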

I am probably missing some things. Any thoughts would be appreciated!
Also, I still can’t get the learning curves to show up. I forced a relogin with wandb and tried a new API key, but that didn’t help.

Thanks,
Sumithra


Hi @valhalla ! I’m sort of stuck not being able to get plots… do you have any further thoughts?

Thanks,
Sumithra