Hi there,
I am pretty new to transformers (and to DL in general), and I am having trouble figuring out the following:
I have trained ‘tiny-bert’ via knowledge distillation from a fine-tuned ‘bert-base-cased’, where the goal was sentiment analysis. Here is the code for that process:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, get_scheduler
from datasets import load_dataset
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
import copy
import numpy as np
# ========== 1. Configuration ==========
checkpoint = "bert-base-cased"
batch_size = 8
num_epochs = 10
learning_rate = 5e-5
distill_temp = 3.0
soft_target_loss_w = 0.5
nll_loss_weight = 0.5
reduced_hidden_dim = 1028
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ========== 2. Tokenization ==========
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
def tokenize_input(examples):
    return tokenizer(examples['text'], truncation=True, padding=True, max_length=512)
# ========== 3. Dataset ==========
ds = load_dataset("stanfordnlp/imdb")
ds = ds.map(tokenize_input, batched=True)
ds = ds.remove_columns(['text'])
ds = ds.rename_column('label', 'labels')
# Create a validation split (10% of the train set)
ds = ds['train'].train_test_split(test_size=0.1)
train_dataset = ds['train']
eval_dataset = ds['test']
test_dataset = load_dataset("stanfordnlp/imdb", split="test")
test_dataset = test_dataset.map(tokenize_input, batched=True)
test_dataset = test_dataset.remove_columns(['text'])
test_dataset = test_dataset.rename_column('label', 'labels')
# ========== 4. Dataloaders ==========
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
# ========== 5. Teacher model ==========
model_teacher = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
model_teacher.load_state_dict(torch.load("models/bert_imbd_classifier.bin", map_location="cpu"))
model_teacher.to(device)
model_teacher.eval()
# ========== 6. Student model ==========
model_student = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=2)
model_student.to(device)
# ========== 7. Optimizer and scheduler ==========
optimizer = AdamW(model_student.parameters(), lr=learning_rate)
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
# ========== 8. Loss functions ==========
kd_loss_fn = nn.KLDivLoss(reduction="batchmean")
ce_loss_fn = nn.CrossEntropyLoss()
# ========== 9. Training with distillation ==========
model_student.train()
for epoch in range(num_epochs):
    total_loss = 0
    model_student.train()
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        with torch.no_grad():
            teacher_outputs = model_teacher(**batch)
            soft_targets = nn.functional.softmax(teacher_outputs.logits / distill_temp, dim=-1)
        student_outputs = model_student(**batch)
        student_logits = student_outputs.logits
        soft_preds = nn.functional.log_softmax(student_logits / distill_temp, dim=-1)
        # Distillation loss (scaled by T^2)
        loss_kd = kd_loss_fn(soft_preds, soft_targets) * (distill_temp ** 2)
        # Cross-entropy loss on the hard labels
        loss_ce = ce_loss_fn(student_logits, batch['labels'])
        loss = soft_target_loss_w * loss_kd + nll_loss_weight * loss_ce
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_dataloader)
    print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}")
# ========== 10. Final evaluation ==========
model_student.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in test_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model_student(**batch)
        preds = torch.argmax(outputs.logits, dim=-1)
        correct += (preds == batch["labels"]).sum().item()
        total += batch["labels"].size(0)
accuracy = correct / total
print(f"Accuracy final del modelo estudiante: {accuracy:.4f}")
# ========== 11. Save the model ==========
torch.save(model_student.state_dict(), "models/student_model.bin")
model_student.save_pretrained("student_model/")
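For reference, this is the kind of sanity check I could add right after saving, to confirm that the file on disk actually holds the trained weights (a minimal sketch reusing model_student and the path from the script above):
# Reload the freshly saved state dict and compare it tensor-by-tensor
# with the in-memory student model.
saved_state = torch.load("models/student_model.bin", map_location="cpu")
trained_state = {k: v.cpu() for k, v in model_student.state_dict().items()}
mismatched = [k for k in trained_state if not torch.equal(trained_state[k], saved_state[k])]
print(f"Tensors that differ after save/reload: {len(mismatched)}")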
I end up with a good enough accuracy (around 89%, which is okay for my use case).
The problem is that when I reload the model, the accuracy over the same test set drops dramatically, down to about 50% (i.e., it behaves as if it had never been trained in the first place). Here is the reload/evaluation code:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, get_scheduler
from datasets import load_dataset
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
import copy
import numpy as np
# ======= 1. Configuration =======
checkpoint = "prajjwal1/bert-tiny"
batch_size = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ======= 2. Tokenization =======
def tokenize_input(examples):
    return tokenizer(examples["text"], padding=True, truncation=True, max_length=512)
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# ======= 3. Carga del dataset =======
ds = load_dataset("stanfordnlp/imdb", split = "test")
ds = ds.map(tokenize_input, batched=True)
ds = ds.remove_columns(["text"])
ds = ds.rename_column("label", "labels")
test_dataset = ds
# ======= 4. Creamos el dataloader =======
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
# ======= 5. Cargamos el modelo =======
model_pretrained = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels = 2)
model_pretrained.load_state_dict(torch.load("models/student_model.bin"))
model_pretrained.to(device)
model_pretrained.eval()
# ======= 6. Evaluamos el modelo preentrenado. En principio, 86% =======
correct = 0
total = 0
with torch.no_grad():
for batch in test_dataloader:
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model_pretrained(**batch)
preds = torch.argmax(outputs.logits, dim = -1)
correct += (preds == batch["labels"]).sum().item()
total += batch["labels"].size(0)
acc = correct / total
print(f"Modelo preentrenado con acc final {acc:.4f}")
As I said, I am pretty new to DL, so if you spot any other problems in the code, even ones unrelated to this question, I would appreciate it if you pointed them out.
Thanks in advance!