0% accuracy when fine-tuning from certain models: [CLS] token embedding not learned

When fine-tuning DistilBertForSequenceClassification, something seems to go wrong when initializing from certain models using from_pretrained that prevents the [CLS] token embedding from being learned and therefore prevents classification from producing sensible results.

Below is a custom training script showing a DistilBertForSequenceClassification being initialized from different models' weights as well as being newly created.
The model is then overfitted on a single batch of data.

Expected Behaviour:

  • An accuracy of 100% after overfitting
  • Good accuracy results when performing standard training & eval

Observed Behaviour:

  • This reliably does not work for certain models (see the script below for examples). In those cases, the [CLS] token embedding is the same for every input, so classification cannot work (a quick way to check this is sketched after this list)
  • I have not found any distinguishing characteristics (e.g. in the config) of the models for which it does not work
  • The same issue occurs when using the official classification script (https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification)
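
For reference, a minimal sketch of how this "[CLS] embedding is the same for every input" symptom can be checked (the example sentences are arbitrary placeholders):

import torch
from transformers import AutoTokenizer, AutoModel

model_name = "johannes-garstenauer/distilbert-heaps-masked"  # affected checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()

texts = ["first example input", "a completely different input"]
with torch.no_grad():
    for text in texts:
        inputs = tokenizer(text, return_tensors="pt")
        # last_hidden_state[:, 0] is the hidden state of the first ([CLS]) token
        cls_embedding = model(**inputs).last_hidden_state[:, 0]
        print(text, cls_embedding[0, :5])
# If the printed values are (near-)identical for both inputs, the classification
# head sitting on top of [CLS] has nothing to distinguish the inputs by.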

I'd be happy about any input and to hear whether someone else has experienced the same issue.

import datasets
import evaluate
import torch
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DistilBertConfig,
    DistilBertForSequenceClassification
)

model_name = "johannes-garstenauer/distilbert-heaps-masked" # This doesn't work
#model_name = "distilbert-base-uncased" # This works
#model_name="AyoubChLin/distilbert_cnn_news" # This works
do_finetune = True

print("\n")
print(f"Finetuning {do_finetune}")
if do_finetune:
    print(f"Model: {model_name}")
print("\n")

tokenizer = AutoTokenizer.from_pretrained(model_name)


def preprocess_function(examples):
    return tokenizer(examples["struct"], truncation=True, max_length=512)


ds_name = "johannes-garstenauer/balanced_factor_3_structs_reduced_5labelled_large"
raw_dataset = datasets.load_dataset(ds_name, split="train[:1%]")

tokenized_datasets = raw_dataset.map(preprocess_function, batched=True)
tokenized_datasets = tokenized_datasets.train_test_split(test_size=0.05)

if do_finetune:
    print("Finetuning")
    # When using model 'AyoubChLin/distilbert_cnn_news', you might have to set num_labels=6 to avoid an error
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=5)
else:
    config = DistilBertConfig(output_hidden_states=True, num_labels=5)
    model = DistilBertForSequenceClassification(config)
print(model.config)

args = TrainingArguments(
    f"distilbert-finetuned",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=3,
)

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = torch.argmax(predictions, dim=-1)
    return metric.compute(predictions=predictions, references=labels)


trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    compute_metrics=compute_metrics,
    tokenizer=tokenizer
)

# Overfitting the model on one batch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

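# Grab a single batch from the Trainer's training dataloader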
for batch in trainer.get_train_dataloader():
    break

batch = {k: v.to(device) for k, v in batch.items()}
trainer.create_optimizer()

for _ in tqdm(range(40)):
    outputs = trainer.model(**batch)
    loss = outputs.loss
    loss.backward()
    trainer.optimizer.step()
    trainer.optimizer.zero_grad()

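# Check predictions on the same batch the model was just overfitted on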
with torch.no_grad():
    outputs = trainer.model(**batch)
preds = outputs.logits
labels = batch["labels"]
print(compute_metrics((preds, labels)))
print(batch['labels'])

The solution to this bug was in the tokenizer: the [PAD] token id used by the model and by the tokenizer did not match.
The solution is described here: DistilBertForSequenceClassification 0% accuracy when fine-tuning (using from_pretrained()) · Issue #26034 · huggingface/transformers · GitHub
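
For anyone hitting the same symptom, a minimal sketch of how to make the mismatch visible (not code from the linked issue; it just compares the tokenizer's [PAD] id with the one stored in the model config):

from transformers import AutoConfig, AutoTokenizer

model_name = "johannes-garstenauer/distilbert-heaps-masked"  # affected checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

# The id the tokenizer assigns to [PAD] ...
print("tokenizer pad_token_id:", tokenizer.pad_token_id)
# ... versus the padding id stored in the model config
print("config pad_token_id:", config.pad_token_id)

# If the two disagree, the inputs produced by this tokenizer do not line up with
# the embeddings the model was pretrained with. Loading the tokenizer that was
# actually used during pretraining (or fixing the special-token ids) resolves
# the 0% accuracy behaviour described above.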