Accuracy decreasing after saving/reloading my model

Hi there,
I am pretty new to Transformers (and to DL in general), and I am having trouble figuring out the following:
I have trained ‘tiny-bert’ through a knowledge distillation process from a fine-tuned ‘bert-base-cased’, where the goal was sentiment analysis. Here is the code for this process:

from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, get_scheduler
from datasets import load_dataset
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW

# ========== 1. Configuration ==========
checkpoint = "bert-base-cased"
batch_size = 8
num_epochs = 10
learning_rate = 5e-5
distill_temp = 3.0
soft_target_loss_w = 0.5
nll_loss_weight = 0.5
reduced_hidden_dim = 1028  # defined here but not used anywhere below

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========== 2. Tokenization ==========
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def tokenize_input(examples):
    return tokenizer(examples['text'], truncation=True, padding=True, max_length=512)

# ========== 3. Dataset ==========
ds = load_dataset("stanfordnlp/imdb")
ds = ds.map(tokenize_input, batched=True)
ds = ds.remove_columns(['text'])
ds = ds.rename_column('label', 'labels')

# Create a validation split (10% of the train set)
ds = ds['train'].train_test_split(test_size=0.1)
train_dataset = ds['train']
eval_dataset = ds['test']
test_dataset = load_dataset("stanfordnlp/imdb", split="test")
test_dataset = test_dataset.map(tokenize_input, batched=True)
test_dataset = test_dataset.remove_columns(['text'])
test_dataset = test_dataset.rename_column('label', 'labels')

# ========== 4. Dataloaders ==========
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)

# ========== 5. Teacher model ==========
model_teacher = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
model_teacher.load_state_dict(torch.load("models/bert_imbd_classifier.bin", map_location="cpu"))
model_teacher.to(device)
model_teacher.eval()

# ========== 6. Student model ==========
model_student = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=2)

model_student.to(device)

# ========== 7. Optimizer and scheduler ==========
optimizer = AdamW(model_student.parameters(), lr=learning_rate)
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

# ========== 8. Loss functions ==========
kd_loss_fn = nn.KLDivLoss(reduction="batchmean")
ce_loss_fn = nn.CrossEntropyLoss()

# ========== 9. Training with distillation ==========
for epoch in range(num_epochs):
    total_loss = 0
    model_student.train()

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()

        with torch.no_grad():
            teacher_outputs = model_teacher(**batch)
            soft_targets = nn.functional.softmax(teacher_outputs.logits / distill_temp, dim=-1)

        student_outputs = model_student(**batch)
        student_logits = student_outputs.logits
        soft_preds = nn.functional.log_softmax(student_logits / distill_temp, dim=-1)

        # Distillation loss: KLDivLoss expects log-probabilities from the student
        # and probabilities from the teacher; the T**2 factor keeps the soft-target
        # gradients on a comparable scale across temperatures
        loss_kd = kd_loss_fn(soft_preds, soft_targets) * (distill_temp ** 2)

        # Standard cross-entropy loss on the hard labels
        loss_ce = ce_loss_fn(student_logits, batch['labels'])

        loss = soft_target_loss_w * loss_kd + nll_loss_weight * loss_ce
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_dataloader)
    print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}")

# ========== 10. Final evaluation ==========
model_student.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in test_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model_student(**batch)
        preds = torch.argmax(outputs.logits, dim=-1)
        correct += (preds == batch["labels"]).sum().item()
        total += batch["labels"].size(0)

accuracy = correct / total
print(f"Accuracy final del modelo estudiante: {accuracy:.4f}")

# ========== 11. Save the model ==========
torch.save(model_student.state_dict(), "models/student_model.bin")

model_student.save_pretrained("student_model/")

I end up with a good enough accuracy (around 89%, which is okay for my use case).

The problem is that, when I reload the model, the accuracy on the same test dataset drops significantly, down to about 50% (i.e., it behaves as if it had never been trained in the first place).
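I am not sure how to debug this, but one sanity check I thought of (just a minimal sketch, run in the same session right after training, using the same "models/student_model.bin" path as above) is to compare the saved weights with the in-memory ones, to rule out the weights getting corrupted during the save/load round trip:

import torch
from transformers import AutoModelForSequenceClassification

# Fresh bert-tiny classifier, then load the weights that were just saved
reloaded = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=2)
reloaded.load_state_dict(torch.load("models/student_model.bin", map_location="cpu"))

# Compare against the in-memory trained student, tensor by tensor
trained_state = {k: v.detach().cpu() for k, v in model_student.state_dict().items()}
for name, tensor in reloaded.state_dict().items():
    if not torch.equal(tensor, trained_state[name]):
        print(f"Mismatch in {name}")

Here is the script I use to reload and evaluate the model: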

from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
        
# ======= 1. Configuration =======
checkpoint = "prajjwal1/bert-tiny"
batch_size = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ======= 2. Tokenization =======
def tokenize_input(examples):
    return tokenizer(examples["text"], padding=True, truncation=True, max_length=512)

if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # ======= 3. Load the dataset =======
    ds = load_dataset("stanfordnlp/imdb", split = "test")
    ds = ds.map(tokenize_input, batched=True)
    ds = ds.remove_columns(["text"])
    ds = ds.rename_column("label", "labels")
    test_dataset = ds

    # ======= 4. Create the dataloader =======
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)

    # ======= 5. Load the model =======
    model_pretrained = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
    model_pretrained.load_state_dict(torch.load("models/student_model.bin", map_location="cpu"))
    model_pretrained.to(device)
    model_pretrained.eval()

    # ======= 6. Evaluate the reloaded model. In principle, 86% =======
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in test_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model_pretrained(**batch)
            preds = torch.argmax(outputs.logits, dim=-1)
            correct += (preds == batch["labels"]).sum().item()
            total += batch["labels"].size(0)

    acc = correct / total
    print(f"Modelo preentrenado con acc final {acc:.4f}")


As I said, I am pretty new to DL, so if you find any other problem in the code that is not related to the question, I would appreciate it if you let me know.

Thanks in advance! :blush:


I think you forgot to save and load the tokenizer. During training you tokenize everything with the bert-base-cased tokenizer, but your reload script builds its tokenizer from prajjwal1/bert-tiny, so the reloaded model receives input IDs that do not match the ones it was trained on.

# after finishing training…
model_student.eval()                                   
model_student.save_pretrained("student_model/")         # saves config.json + the model weights
tokenizer.save_pretrained("student_model/")             # saves tokenizer.json + vocab files

# when reloading...
from transformers import AutoTokenizer, AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("student_model/")
tokenizer = AutoTokenizer.from_pretrained("student_model/")
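If you want to double-check that this is what is going on, here is a quick sketch of a comparison (using the checkpoints from your scripts): encode the same sentence with the tokenizer the data was trained on ("bert-base-cased") and with the one the reload script builds from "prajjwal1/bert-tiny", and compare the IDs.

from transformers import AutoTokenizer

train_tok = AutoTokenizer.from_pretrained("bert-base-cased")       # used to tokenize the training data
reload_tok = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")  # what the reload script uses

sample = "This movie was surprisingly good."
print(train_tok(sample)["input_ids"])
print(reload_tok(sample)["input_ids"])
# If the two ID sequences differ, the reloaded model is being fed token IDs it
# never saw during training, which would explain the near-random ~50% accuracy.

Once the tokenizer is saved next to the model, loading both from "student_model/" keeps them in sync.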