Regarding Training a Task Specific Knowledge Distillation model

I was referring to this code:

From @philschmid

I could follow most of the code, but I had a few doubts. Please help me clarify them.

In this code below:

class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        
        self.teacher = teacher_model
        
        # place teacher on same device as student
        self._move_model_to_device(self.teacher,self.model.device)
        
        self.teacher.eval()

If I use a fine-tuned teacher model, it is never fine-tuned further during task-specific distillation training, because of the self.teacher.eval() line in the code, right? Only the output of the teacher model is used for the loss calculation?

I couldn't follow the line self._move_model_to_device(self.teacher,self.model.device). What is it actually doing?

In task-specific distillation training I am fine-tuning my student model, but I pass both models to the DistillationTrainer. Where does it make sure that only the student model's weights are updated and not the teacher's?

trainer = DistillationTrainer(
    student_model,
    training_args,
    teacher_model=teacher_model,
    train_dataset=train_data,
    eval_dataset=val_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

As far as I can tell, the student model is the one being passed to the Trainer's model property, and when trainer.train() is called the Trainer only looks at its own model to adjust the weights. As you pointed out, the teacher model is set to eval mode, and it's only used in the overridden compute_loss function. (More info about compute_loss here: Trainer) The DistillationTrainer class is just a custom subclass, and the teacher_model never actually gets passed into the Trainer.
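
To make that concrete, here's a rough sketch of how the whole thing fits together, including the overridden compute_loss. This is my reconstruction from the snippets in this thread, not a copy of the notebook, and I'm assuming an alpha weighting hyperparameter lives on the training arguments next to temperature:

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer

class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)      # student model + training args go to the Trainer
        self.teacher = teacher_model
        self._move_model_to_device(self.teacher, self.model.device)
        self.teacher.eval()                    # teacher stays in eval mode the whole time

    def compute_loss(self, model, inputs, return_outputs=False):
        # Student forward pass: gradients ARE tracked, so only these weights get updated
        outputs_student = model(**inputs)
        student_loss = outputs_student.loss    # the usual task loss (e.g. cross-entropy)

        # Teacher forward pass: wrapped in no_grad, so no gradients ever reach the teacher
        with torch.no_grad():
            outputs_teacher = self.teacher(**inputs)

        # Soften both distributions with the temperature and compare them with KL divergence
        loss_function = nn.KLDivLoss(reduction="batchmean")
        loss_logits = loss_function(
            F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),
            F.softmax(outputs_teacher.logits / self.args.temperature, dim=-1),
        ) * (self.args.temperature ** 2)

        # Blend the task loss and the distillation loss (alpha is an assumed hyperparameter here)
        loss = self.args.alpha * student_loss + (1.0 - self.args.alpha) * loss_logits
        return (loss, outputs_student) if return_outputs else loss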

I think self._move_model_to_device(self.teacher,self.model.device) just sets the teacher model to use GPU if it’s available (or at least puts it on whatever device the student model is on), since the Trainer class already does that automatically for the student model when it’s passed in: transformers/trainer.py at c4ad38e5ac69e6d96116f39df789a2369dd33c21 · huggingface/transformers · GitHub
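
For what it's worth, my understanding is that the helper boils down to roughly this (a simplification; the real method in the Trainer handles a couple of extra cases):

# Roughly equivalent to self._move_model_to_device(self.teacher, self.model.device):
# move the teacher onto whatever device (CPU or GPU) the Trainer already chose for the student.
self.teacher = self.teacher.to(self.model.device)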

Hope this helps!


Yes, it helps. A few more clarifications on your answer!

But in the class DistillationTrainer(Trainer), both models are passed, as you can see here:

trainer = DistillationTrainer(
    student_model,                   # <-- student passed here
    training_args,
    teacher_model=teacher_model,     # <-- teacher passed here
    train_dataset=train_data,
    eval_dataset=val_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

Why do you say that <<the student model is the one being passed to the Trainer's model property>>?

And the teacher model is only there to calculate the loss?

One more basic doubt:

In this code:

loss_function = nn.KLDivLoss(reduction="batchmean")
loss_logits = loss_function(
    F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),
    F.softmax(outputs_teacher.logits / self.args.temperature, dim=-1),
) * (self.args.temperature ** 2)

Why are we using log_softmax for the student and softmax for the teacher?
Also, what is the purpose of KLDivLoss? What is it doing?

The DistillationTrainer is a new custom class that's being created in your notebook, which subclasses the Trainer class (from Hugging Face's transformers). So even though you pass both the student_model and the teacher_model to the DistillationTrainer, note that this section of the code:

class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        # etc...

means that the teacher model is being saved in the DistillationTrainer as self.teacher, while the student model is being passed up through to the Trainer’s init function as the first parameter when you run:

trainer = DistillationTrainer(
    student_model, # This is a positional argument, captured in *args and passed to super().__init__
    training_args,
    teacher_model=teacher_model,
    train_dataset=train_data,
    eval_dataset=val_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

If we look at where self.teacher gets used, it’s actually only called in the compute_loss() function. (The Trainer class doesn’t use a self.teacher property)

Are you familiar with subclasses and super in Python? If not, reading into that might help a bit!
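
Here's a tiny, self-contained example of the same pattern in plain Python (nothing to do with transformers, the names are just for illustration): the first positional argument is captured in *args and forwarded to the parent class, while the keyword-only argument stays on the subclass.

class Parent:
    def __init__(self, model):
        self.model = model                 # the parent stores the first positional argument

class Child(Parent):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)  # "model" travels up to Parent unchanged
        self.teacher = teacher_model       # "teacher_model" stays on the subclass

c = Child("student", teacher_model="teacher")
print(c.model)    # student  (this is what the Trainer sees as its model)
print(c.teacher)  # teacher  (only the subclass knows about this one)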

As for the other questions, they’re out of my expertise but this GitHub issue seems to explain it pretty well! (kd loss · Issue #2 · haitongli/knowledge-distillation-pytorch · GitHub) Here’s a Reddit thread specifically about KLDivLoss, which seems to be commonly used in knowledge distillation: Reddit - Dive into anything
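
One mechanical detail I can add from the PyTorch docs: nn.KLDivLoss expects its input to be log-probabilities and its target to be plain probabilities, which is why the student's logits go through log_softmax while the teacher's go through softmax. A minimal, runnable illustration (the numbers are made up):

import torch
import torch.nn as nn
import torch.nn.functional as F

kl = nn.KLDivLoss(reduction="batchmean")

student_logits = torch.tensor([[2.0, 0.5, -1.0]])  # made-up logits: one example, 3 classes
teacher_logits = torch.tensor([[1.8, 0.7, -0.9]])
temperature = 2.0  # softens both distributions before comparing them

loss = kl(
    F.log_softmax(student_logits / temperature, dim=-1),  # input: log-probabilities (student)
    F.softmax(teacher_logits / temperature, dim=-1),      # target: probabilities (teacher)
) * (temperature ** 2)

print(loss)  # close to 0 here because the two distributions are already similar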

This is pretty neat, I've learned a lot researching your question :hugs:


I learned a lot through your answers!! Thanks.


Hi @NimaBoscarino, why do they mention in the paper that <<In our experiments, we have observed that distilled models do not work well when distilled to a different model type. Therefore, we restricted our setup to avoid distilling RoBERTa model to BERT or vice versa. The major difference between the two model groups is the input token (sub-word) embedding. We think that different input embedding spaces result in different output embedding spaces, and knowledge transfer with different spaces does not work well>>?

In the end, we are only taking a KL divergence loss between the two logit vectors, right? So can't we use a RoBERTa XLM large model as the teacher and BERT base as the student, even though they use different tokenizations? Or can we still push our code to train whatever we want, but technically it's not wise to do so?

Oooh, I don’t know enough about that to comment on it :grimacing: I think that what they mean in the paper is that you could use those as teacher + student, but you might not be super impressed with the results. That’s my intuition though, so I might be wrong!


Can anyone help me perform task-specific knowledge distillation on an NER task?