Distillation: create student model from a different base model than teacher


The current implementation of distillation in examples/seq2seq/distillation.py creates a student model by copying selected layers from the teacher model. However, I am interested in creating a student model from a different base model, e.g., a teacher model using t5-large and a student model using t5-small. I have made changes here. I think I am missing something, because when I try to run this using:

python distillation.py --teacher t5-large --data_dir $NQOPEN_DIR \
--student_base_model t5-small --tokenizer_name t5-small \
--learning_rate=3e-4 --freeze_encoder --freeze_embeds \
--do_train --train_batch_size 32 \
--do_predict --n_train 10 \
--model_name_or_path t5-small --eval_beams 2 --eval_max_gen_length 142 \
--val_check_interval 0.25 --n_val 10 \
--output_dir distilt5 --gpus 1 --logger_name wandb

I get the following error. Could you please let me know what I am missing?

Traceback (most recent call last):
  File "distillation.py", line 361, in <module>
  File "distillation.py", line 352, in distill_main
    return ft_main(args, model=model)
  File "/home/sumithrab/transformers/examples/seq2seq/finetune.py", line 407, in main
    trainer: pl.Trainer = generic_train(
  File "/home/sumithrab/transformers/examples/lightning_base.py", line 382, in generic_train
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1003, in fit
    results = self.single_gpu_train(model)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/distrib_parts.py", line 186, in single_gpu_train
    results = self.run_pretrain_routine(model)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in run_pretrain_routine
    eval_results = self._evaluate(model,
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 293, in _evaluate
    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 470, in evaluation_forward
    output = model.validation_step(*args)
  File "/home/sumithrab/transformers/examples/seq2seq/finetune.py", line 181, in validation_step
    return self._generative_step(batch)
  File "/home/sumithrab/transformers/examples/seq2seq/finetune.py", line 225, in _generative_step
    loss_tensors = self._step(batch)
  File "distillation.py", line 211, in _step
    outputs = self.teacher(
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 1201, in forward
    decoder_outputs = self.decoder(
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 757, in forward
    layer_outputs = layer_module(
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 547, in forward
    cross_attention_outputs = self.layer[1](
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 469, in forward
    attention_output = self.EncDecAttention(
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 356, in forward
    k = shape(self.k(k))  # (bs, n_heads, qlen, dim_per_head)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 91, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/functional.py", line 1676, in linear
    output = input.matmul(weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0

Also, I am not sure whether the --model_name_or_path and --tokenizer_name arguments are correct: should they be t5-large or t5-small?


It might not work, since t5-large and t5-small have a different hidden_size, number of attention heads, and feed-forward dimension. The layers are sized differently, so it's not possible to copy weights when the sizes don't match.
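To make the mismatch concrete, the sketch below derives the shapes of a few per-block weights from the published hyperparameters of the two checkpoints (t5-small: d_model=512, d_ff=2048, 8 heads with d_kv=64, 6 layers; t5-large: d_model=1024, d_ff=4096, 16 heads with d_kv=64, 24 layers). The weight labels and helper functions are illustrative only, not the actual transformers parameter names; shapes follow PyTorch's nn.Linear convention of (out_features, in_features).

```python
# Published hyperparameters of the original T5 checkpoints.
T5_CONFIGS = {
    "t5-small": {"d_model": 512, "d_ff": 2048, "num_heads": 8, "d_kv": 64, "num_layers": 6},
    "t5-large": {"d_model": 1024, "d_ff": 4096, "num_heads": 16, "d_kv": 64, "num_layers": 24},
}


def layer_weight_shapes(cfg):
    """Shapes of a few weights in one T5 block (illustrative labels, nn.Linear layout)."""
    inner = cfg["num_heads"] * cfg["d_kv"]  # attention projection width
    return {
        "attention.q": (inner, cfg["d_model"]),
        "attention.o": (cfg["d_model"], inner),
        "ff.wi": (cfg["d_ff"], cfg["d_model"]),
        "ff.wo": (cfg["d_model"], cfg["d_ff"]),
    }


def copyable_weights(teacher, student):
    """Labels of weights whose shapes match, i.e. that could be copied one-to-one."""
    t = layer_weight_shapes(T5_CONFIGS[teacher])
    s = layer_weight_shapes(T5_CONFIGS[student])
    return [name for name in t if t[name] == s[name]]


print(layer_weight_shapes(T5_CONFIGS["t5-small"]))
print(layer_weight_shapes(T5_CONFIGS["t5-large"]))
print(copyable_weights("t5-large", "t5-small"))  # -> []
```

Every listed weight differs in shape, and the layer counts differ too (24 vs 6). That is why copying layers from the teacher only works when the student keeps the teacher's width and just drops layers.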

cc. @sshleifer

My code changes do two things–

  1. Create student model directly from the specified base model and not copy layers from the teacher.
  2. Not try to calculate hidden loss if any of the hidden layer sizes are different.
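Change (2) can be sketched as a loss-blending function that only evaluates the hidden-state term when the teacher and student widths agree, since an MSE between hidden states of different sizes is undefined without an extra learned projection. This is a minimal sketch, assuming loss weights named like the alpha_ce / alpha_mlm / alpha_hid options of distillation.py; the default values are illustrative, and passing the hidden loss as a zero-argument callable is a hypothetical design choice so that the computation is skipped entirely when it does not apply.

```python
def blended_loss(ce_loss, lm_loss, compute_hid_loss, teacher_d_model, student_d_model,
                 alpha_ce=0.8, alpha_mlm=0.2, alpha_hid=0.0):
    """Weighted sum of distillation losses (weight names/defaults are assumptions).

    ce_loss:          soft-target loss between student and teacher logits.
    lm_loss:          the student's ordinary cross-entropy on the labels.
    compute_hid_loss: zero-argument callable returning the hidden-state loss;
                      it is only called when the hidden sizes match.
    """
    total = alpha_ce * ce_loss + alpha_mlm * lm_loss
    # Skip the hidden-state term when it is disabled or the widths differ
    # (e.g. t5-large d_model=1024 vs t5-small d_model=512).
    if alpha_hid > 0 and teacher_d_model == student_d_model:
        total += alpha_hid * compute_hid_loss()
    return total


print(blended_loss(1.0, 2.0, lambda: 5.0, 1024, 512, alpha_hid=1.0))  # hidden term skipped
```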

Is there something else I need to be doing? Any ideas on how to get past this error?

@sshleifer: would appreciate your insight if you get a chance…

What Suraj said is correct. We haven't really tried this. What's happening is you are calling the teacher decoder on the student encoder outputs.

Two options for you:
(1) If you pass --no_teacher (or just finetune.py your student), things should work.
(2) Otherwise (if you want to supervise the student so that student logits ~= teacher logits), you need to roll up your sleeves and edit SummarizationDistiller._step to run both the teacher encoder and the student encoder.

I would recommend (1). (2) will be much slower, and I suspect it won't perform much better.

Share your results!
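For option (2), the teacher runs its own encoder and decoder (under torch.no_grad()), the student runs its own, and only the final logits are compared. That comparison works across model sizes because t5-small and t5-large share the same tokenizer and vocabulary, so both produce logits of shape (batch, tgt_len, vocab_size) even though their hidden sizes differ. Below is a minimal NumPy sketch of a temperature-softened KL loss masked over padding. It is the standard soft-target distillation loss, not necessarily the exact loss distillation.py computes, and the function names are hypothetical. In PyTorch the same quantity can be computed with torch.nn.functional.kl_div on student log-probabilities and teacher probabilities.

```python
import numpy as np


def log_softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def soft_target_loss(student_logits, teacher_logits, pad_mask, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions.

    student_logits, teacher_logits: arrays of shape (batch, tgt_len, vocab).
    pad_mask: boolean array of shape (batch, tgt_len), True for real target tokens.
    """
    assert student_logits.shape == teacher_logits.shape, "logit shapes must match"
    log_p_s = log_softmax(student_logits / temperature)
    log_p_t = log_softmax(teacher_logits / temperature)
    p_t = np.exp(log_p_t)
    kl_per_token = (p_t * (log_p_t - log_p_s)).sum(axis=-1)  # (batch, tgt_len)
    kl = kl_per_token[pad_mask].mean()  # average over non-padding positions only
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures.
    return kl * temperature ** 2


teacher = np.array([[[1.0, 2.0, 0.5], [0.0, 3.0, 1.0]]])
student = np.zeros_like(teacher)
mask = np.array([[True, True]])
print(soft_target_loss(student, teacher, mask))  # positive; 0.0 when logits match
```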

Thanks @sshleifer. I do need a teacher: I want to use teacher logits to supervise the student, so I need option 2. Regarding calling the teacher decoder on student encoder outputs, I am not doing that.
I get the teacher encoder outputs and pass them to the teacher decoder like this:

            with torch.no_grad():
                # Run the frozen teacher's own encoder; these outputs feed the teacher decoder.
                teacher_enc_outputs, teacher_enc_hid = self.teacher.get_encoder()(
                    input_ids, attention_mask=src_mask, output_hidden_states=True
                )

However, I had a bug where I was overwriting teacher_enc_outputs with the student encoder outputs afterwards. I have fixed that now, but I am getting a different error:
(I tried to share a link to the code changes in my fork/branch but wasn't able to; not sure why.)

Traceback (most recent call last):
  File "distillation.py", line 338, in <module>
  File "distillation.py", line 329, in distill_main
    return ft_main(args, model=model)
  File "/home/sumithrab/transformers/examples/seq2seq/finetune.py", line 408, in main
    trainer: pl.Trainer = generic_train(
  File "/home/sumithrab/transformers/examples/lightning_base.py", line 383, in generic_train
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1003, in fit
    results = self.single_gpu_train(model)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/distrib_parts.py", line 186, in single_gpu_train
    results = self.run_pretrain_routine(model)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in run_pretrain_routine
    eval_results = self._evaluate(model,
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 293, in _evaluate
    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 470, in evaluation_forward
    output = model.validation_step(*args)
  File "/home/sumithrab/transformers/examples/seq2seq/finetune.py", line 181, in validation_step
    return self._generative_step(batch)
  File "/home/sumithrab/transformers/examples/seq2seq/finetune.py", line 226, in _generative_step
    loss_tensors = self._step(batch)
  File "distillation.py", line 210, in _step
    outputs = self.teacher(
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 1201, in forward
    decoder_outputs = self.decoder(
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 757, in forward
    layer_outputs = layer_module(
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 547, in forward
    cross_attention_outputs = self.layer[1](
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 469, in forward
    attention_output = self.EncDecAttention(
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 388, in forward
    position_bias = position_bias + mask  # (bs, n_heads, qlen, klen)
RuntimeError: The size of tensor a (1024) must match the size of tensor b (14) at non-singleton dimension 3

No clue, you're on your own. Try writing a small failing unit test and debugging it with pytest.

E.g., modify test_distill_t5.
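Along the lines of that suggestion, a small pytest-style check can assert the two invariants the tracebacks above complain about before a full training run: the width of the encoder outputs fed to the teacher decoder must equal the teacher's d_model (the first error is a 512-wide student encoder output hitting 1024-wide teacher projections), and the key length seen by cross-attention must equal the source attention-mask length (the second error reports 1024 vs 14). This is a pure-Python sketch over shape tuples; check_cross_attention_inputs and the test names are hypothetical, and in a real test the shapes would come from the tensors inside _step (e.g. teacher_enc_outputs.shape and src_mask.shape).

```python
def check_cross_attention_inputs(enc_hidden_shape, src_mask_shape, decoder_d_model):
    """Return a list of problems with the inputs a decoder's cross-attention will see.

    enc_hidden_shape: (batch, src_len, d_model) of the encoder outputs fed to the decoder.
    src_mask_shape:   (batch, src_len) of the source attention mask.
    """
    problems = []
    if len(enc_hidden_shape) != 3:
        problems.append(f"expected 3-D encoder outputs, got shape {tuple(enc_hidden_shape)}")
        return problems
    batch, src_len, d_model = enc_hidden_shape
    if d_model != decoder_d_model:
        problems.append(f"encoder width {d_model} != decoder d_model {decoder_d_model}")
    if (batch, src_len) != tuple(src_mask_shape):
        problems.append(f"encoder outputs {(batch, src_len)} vs mask {tuple(src_mask_shape)}")
    return problems


def test_teacher_decoder_gets_teacher_encoder_outputs():
    # t5-large encoder outputs for a batch of 2 sources of length 14.
    assert check_cross_attention_inputs((2, 14, 1024), (2, 14), 1024) == []


def test_student_encoder_outputs_are_rejected():
    # t5-small encoder outputs (width 512) fed to the t5-large decoder.
    assert check_cross_attention_inputs((2, 14, 512), (2, 14), 1024)


if __name__ == "__main__":
    test_teacher_decoder_gets_teacher_encoder_outputs()
    test_student_encoder_outputs_are_rejected()
    print("ok")
```

Running pytest on this file collects the two test_ functions; running it directly with python executes them once.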

Thank you! I fixed the issue and training is currently underway. I'll keep you posted on whether it finishes successfully.

Cool. I'd be interested in seeing a PR of what you changed!

