Loading Vision Transformer Model After Changing Its Classifier Head

I added more layers to the classifier head of the ViT model for an image classification task, like this:

import torch.nn as nn
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")

# Define a new sequential module for the classifier
new_classifier = nn.Sequential(
    nn.Linear(768, 256),  # First layer with 256 neurons
    nn.ReLU(),            # Activation function
    nn.BatchNorm1d(256),  # Batch normalization for the first layer
    nn.Dropout(0.1),      # Dropout
    nn.Linear(256, 64),   # Second layer with 64 neurons
    nn.ReLU(),            # Activation function
    nn.BatchNorm1d(64),   # Batch normalization for the second layer
    nn.Dropout(0.1),      # Dropout
    nn.Linear(64, 2)      # Output layer with 2 neurons
)
model.classifier = new_classifier

Then I trained the model using the Trainer() class from Hugging Face and saved the best model to a checkpoint.
When loading the trained model from this checkpoint I get this warning:

Some weights of the model checkpoint at /content/best_models_complex/checkpoint-564 were not used when initializing ViTForImageClassification: ['classifier.8.weight', 'classifier.6.running_var', 'classifier.2.running_mean', 'classifier.6.num_batches_tracked', 'classifier.6.weight', 'classifier.4.bias', 'classifier.2.running_var', 'classifier.2.num_batches_tracked', 'classifier.6.bias', 'classifier.8.bias', 'classifier.2.bias', 'classifier.2.weight', 'classifier.0.weight', 'classifier.0.bias', 'classifier.6.running_mean', 'classifier.4.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at /content/best_models_complex/checkpoint-564 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a downstream task to be able to use it for predictions and inference.

And the model's predictions are very bad (probably because of the randomly initialized weights of the added layers).
So my question is: how do I load the model from the checkpoint without the weights of the new layers being random?

I also wonder how Trainer() manages to load the best model correctly at the end of training if loading the model from a checkpoint doesn't work properly.


To define a custom ViT model for image classification, define it like here, with self.classifier replaced by your custom classification head. What's important is that your class inherits from ViTPreTrainedModel in order to make it work with from_pretrained. This class can be imported from transformers.models.vit.modeling_vit.
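A minimal sketch of such a subclass (the class name ViTWithCustomHead is hypothetical; the head mirrors the one from the question, and the forward signature assumes the usual pixel_values/labels inputs that Trainer passes):

```python
import torch
import torch.nn as nn
from transformers import ViTConfig, ViTModel
from transformers.models.vit.modeling_vit import ViTPreTrainedModel


class ViTWithCustomHead(ViTPreTrainedModel):
    """Hypothetical ViT subclass whose custom head is part of the architecture,
    so from_pretrained can match the classifier.* weights in a checkpoint."""

    def __init__(self, config):
        super().__init__(config)
        self.vit = ViTModel(config, add_pooling_layer=False)
        # Same custom head as in the question; config.hidden_size is 768 for vit-base
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.1),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.1),
            nn.Linear(64, 2),
        )
        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, pixel_values, labels=None):
        outputs = self.vit(pixel_values=pixel_values)
        # Classify from the [CLS] token representation
        logits = self.classifier(outputs.last_hidden_state[:, 0, :])
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}
```

Train this class with Trainer as before; afterwards something like ViTWithCustomHead.from_pretrained("/content/best_models_complex/checkpoint-564") should restore all classifier.* weights instead of reinitializing them, because the module names in the class now match the keys in the checkpoint.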
