Knowledge Distillation of SentenceTransformer - problems making it work

Hi everyone,

I’ve also tried to raise this on GitHub, but since I’m not getting any responses there, I thought I’d try it here. I hope that’s cool.
I’ve fine-tuned a sentence-transformer model and it’s performing very well on my task. It is, however, pretty slow, and I found this very helpful guide and code for distillation: Model Distillation — Sentence-Transformers documentation

My issue is that after distillation the cosine similarities are completely off: sentence pairs that should be similar and pairs that shouldn’t be are both scored as very similar by the distilled model.
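
To illustrate what I mean, here’s a minimal check (the sentences are just stand-ins I made up, since my real texts are protected; the model names match the script below):

from sentence_transformers import SentenceTransformer, util

# Stand-in sentences -- the real data is protected, but the pattern is the same
similar_a = "The invoice was paid on Monday."
similar_b = "Payment for the invoice went through at the start of the week."
unrelated = "The cat is sleeping on the sofa."

teacher = SentenceTransformer('my-semantic')
distilled = SentenceTransformer('my-semantic-distilled')

for name, model in [('teacher', teacher), ('distilled', distilled)]:
    emb = model.encode([similar_a, similar_b, unrelated], convert_to_tensor=True)
    print(name,
          'similar pair:', util.cos_sim(emb[0], emb[1]).item(),
          'unrelated pair:', util.cos_sim(emb[0], emb[2]).item())

With the teacher, the similar pair scores clearly higher than the unrelated pair; with the distilled model both pairs come out almost equally high.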

The process seems very straightforward and most of it is abstracted away, so I feel like I must be missing something obvious :confused:
I’m using the same data I used to fine-tune the model originally, so I don’t think there’s anything wrong with my data.

Here’s my script (a slightly adapted version of the script at https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/distillation/model_distillation.py):


from torch.utils.data import DataLoader
from sentence_transformers import losses, evaluation
from sentence_transformers import LoggingHandler, SentenceTransformer, InputExample
from sentence_transformers.datasets import ParallelSentencesDataset
import logging
import random
import torch
import csv

logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])


# Teacher Model: Model we want to distill to a smaller model
teacher_model_name = 'my-semantic'
teacher_model = SentenceTransformer(teacher_model_name)

output_path = "my-semantic-distilled"

# The student starts out as a full copy of the teacher; layers are pruned below
student_model = SentenceTransformer(teacher_model_name)

# Underlying Hugging Face transformer of the student, so its encoder layers can be dropped
auto_model = student_model._first_module().auto_model

# Which layers to keep from the teacher model. The kept layers are spread evenly across the teacher's original layers
#layers_to_keep = [5]
#layers_to_keep = [3, 7]
#layers_to_keep = [3, 7, 11]
layers_to_keep = [1, 4, 7, 10]
#layers_to_keep = [0, 2, 4, 6, 8, 10]
#layers_to_keep = [0, 1, 3, 4, 6, 7, 9, 10]

logging.info("Remove layers from student. Only keep these layers: {}".format(layers_to_keep))
new_layers = torch.nn.ModuleList([layer_module for i, layer_module in enumerate(auto_model.encoder.layer) if i in layers_to_keep])
auto_model.encoder.layer = new_layers
auto_model.config.num_hidden_layers = len(layers_to_keep)

inference_batch_size = 64
train_batch_size = 64

train_sentences = []
dev_sentences = []

# Some of the CSV fields are long, so raise the default field size limit
csv.field_size_limit(1000000)

logging.info("Collecting training data")
with open("../data/sample_pairs.csv", encoding="utf-8") as f:
    csv_reader = csv.reader(f)    
    for cnt, sample in enumerate(csv_reader):
        train_sentences.append(sample[0])

logging.info("Collecting dev data")
with open("../data/dev_sample_pairs.csv", encoding="utf-8") as f:
    csv_reader = csv.reader(f)    
    for cnt, sample in enumerate(csv_reader):
        dev_sentences.append(sample[0])

dev_samples = []
with open("../data/sample_triplets.csv", encoding="utf-8") as f:
    csv_reader = csv.reader(f)
    for sample in csv_reader:
        dev_samples.append(InputExample(texts=[sample[0], sample[1]], label=int(sample[2])))

random.shuffle(dev_samples)
dev_samples = dev_samples[0:100]

# The teacher's embeddings of the training sentences are used as regression targets
# for the student (trained with MSELoss below)
train_data = ParallelSentencesDataset(student_model=student_model, teacher_model=teacher_model, batch_size=inference_batch_size, use_embedding_cache=False)
train_data.add_dataset([[sent] for sent in train_sentences], max_sentence_length=512)

train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
train_loss = losses.MSELoss(model=student_model)

print("Teacher:")
print(teacher_model)
print("Student:")
print(student_model)
print("Sample:")
print(train_data[0])

logging.info("Creating evaluators")
dev_evaluator_mse = evaluation.MSEEvaluator(dev_sentences, dev_sentences, teacher_model=teacher_model)
dev_evaluator_sts = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')

logging.info("Eval teacher model")
dev_evaluator_sts(teacher_model)

epochs = 20
warmup_steps = int(len(train_dataloader) * epochs * 0.1)
logging.info("Starting to fit model")
student_model.fit(train_objectives=[(train_dataloader, train_loss)],
                  evaluator=evaluation.SequentialEvaluator([dev_evaluator_sts, dev_evaluator_mse]),
                  epochs=epochs,
                  warmup_steps=warmup_steps,
                  evaluation_steps=7000,
                  output_path=output_path,
                  save_best_model=True,
                  show_progress_bar=False,
                  optimizer_params={'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False})

Here’s the console output from a few epochs:

2022-04-08 08:49:22 - Remove layers from student. Only keep these layers: [1, 4, 7, 10]
2022-04-08 08:49:22 - Collecting training data
2022-04-08 08:49:23 - Collecting dev data
Teacher:
SentenceTransformer(
(0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel
(1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)
Student:
SentenceTransformer(
(0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel
(1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)
Sample:
label: [ 0.48059258 0.1424962 -0.15087782 … 0.04953826 1.0277015
-0.46856898], texts: Das ist ein Beispiel, der echte Text ist leider geschützt).
(The placeholder text translates to: “This is an example, the real text is unfortunately protected.”)

2022-04-08 08:49:55 - Creating evaluators

2022-04-08 10:52:35 - EmbeddingSimilarityEvaluator: Evaluating the model on sts-dev dataset after epoch 0:

2022-04-08 10:52:52 - Cosine-Similarity : Pearson: 0.2111 Spearman: 0.2317
2022-04-08 10:52:52 - Manhattan-Distance: Pearson: 0.2497 Spearman: 0.2317
2022-04-08 10:52:52 - Euclidean-Distance: Pearson: 0.2417 Spearman: 0.2429
2022-04-08 10:52:52 - Dot-Product-Similarity: Pearson: -0.0054 Spearman: -0.0295

2022-04-08 10:55:05 - MSE evaluation (lower = better) on dataset after epoch 0:
2022-04-08 10:55:05 - MSE (*100): 65.614074
2022-04-08 10:55:05 - Save model to my-semantic-distilled

2022-04-08 11:09:36 - EmbeddingSimilarityEvaluator: Evaluating the model on sts-dev dataset after epoch 1:

2022-04-08 11:09:52 - Cosine-Similarity : Pearson: 0.2226 Spearman: 0.2766
2022-04-08 11:09:52 - Manhattan-Distance: Pearson: 0.2924 Spearman: 0.2934
2022-04-08 11:09:52 - Euclidean-Distance: Pearson: 0.2528 Spearman: 0.2682
2022-04-08 11:09:52 - Dot-Product-Similarity: Pearson: 0.1022 Spearman: 0.0997

2022-04-08 11:12:00 - MSE evaluation (lower = better) on dataset after epoch 1:
2022-04-08 11:12:00 - MSE (*100): 53.866756
2022-04-08 11:12:00 - Save model to my-semantic-distilled

2022-04-08 11:26:27 - EmbeddingSimilarityEvaluator: Evaluating the model on sts-dev dataset after epoch 2:

2022-04-08 11:26:44 - Cosine-Similarity : Pearson: 0.1309 Spearman: 0.2260
2022-04-08 11:26:44 - Manhattan-Distance: Pearson: 0.2826 Spearman: 0.2794
2022-04-08 11:26:44 - Euclidean-Distance: Pearson: 0.1996 Spearman: 0.2513
2022-04-08 11:26:44 - Dot-Product-Similarity: Pearson: -0.0222 Spearman: -0.0098

2022-04-08 11:28:53 - MSE evaluation (lower = better) on dataset after epoch 2:
2022-04-08 11:28:53 - MSE (*100): 41.433170
2022-04-08 11:28:53 - Save model to my-semantic-distilled

2022-04-08 11:43:48 - EmbeddingSimilarityEvaluator: Evaluating the model on sts-dev dataset after epoch 3:

2022-04-08 11:44:05 - Cosine-Similarity : Pearson: 0.1046 Spearman: 0.2148
2022-04-08 11:44:05 - Manhattan-Distance: Pearson: 0.2685 Spearman: 0.2710
2022-04-08 11:44:05 - Euclidean-Distance: Pearson: 0.1869 Spearman: 0.2541
2022-04-08 11:44:05 - Dot-Product-Similarity: Pearson: -0.1361 Spearman: -0.1137

2022-04-08 11:46:18 - MSE evaluation (lower = better) on dataset after epoch 3:

2022-04-08 11:46:18 - MSE (*100): 33.955178
2022-04-08 11:46:18 - Save model to my-semantic-distilled

As you can see, it’s not going great. The MSE is going down, but the embedding similarity is actually getting worse at the same time (this also confuses me).
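
For what it’s worth, here’s a quick sanity check I put together to quantify the “everything looks similar” behaviour (it reuses teacher_model, student_model and dev_sentences from the script above; the sample size of 50 is arbitrary):

import random
from sentence_transformers import util

sample = random.sample(dev_sentences, 50)

for name, model in [('teacher', teacher_model), ('student', student_model)]:
    emb = model.encode(sample, convert_to_tensor=True)
    sims = util.cos_sim(emb, emb)  # 50 x 50 cosine-similarity matrix
    n = sims.shape[0]
    # Mean cosine similarity over all off-diagonal pairs; if a model maps
    # everything into roughly the same direction, this average ends up close to 1
    mean_off_diag = (sims.sum() - sims.diag().sum()) / (n * (n - 1))
    print(name, 'mean pairwise cosine:', mean_off_diag.item())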

Here’s the evaluation of the teacher btw:

2022-04-09 09:02:48 - Eval teacher model
2022-04-09 09:02:48 - EmbeddingSimilarityEvaluator: Evaluating the model on sts-dev dataset:
2022-04-09 09:05:54 - Cosine-Similarity : Pearson: 0.9558 Spearman: 0.8392
2022-04-09 09:05:54 - Manhattan-Distance: Pearson: 0.9671 Spearman: 0.8399
2022-04-09 09:05:54 - Euclidean-Distance: Pearson: 0.9672 Spearman: 0.8392
2022-04-09 09:05:54 - Dot-Product-Similarity: Pearson: 0.9533 Spearman: 0.8399

Any pointers would be greatly appreciated!