Saved PyTorch Lightning / Hugging Face model is not loading properly

Hello,

I’m training a classifier on a doodle dataset using Hugging Face and PyTorch Lightning: Google’s ViT model with one extra linear layer on top. After 5 hours of training, test accuracy increased from 0.0 to 0.75. However, when I saved the model and later loaded it, test accuracy had fallen back to 0. This has happened 2-3 times now. Could someone help me figure out what I am doing wrong?

I was using Google’s Quick Draw dataset downloaded from here: Quick, Draw! Doodle Recognition Challenge | Kaggle

Here are the relevant parts of my code:

import torch
from torch import nn
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from transformers import ViTModel, ViTConfig
from transformers import ViTForImageClassification, AdamW

model_name = 'google/vit-base-patch16-224-in21k'

#################################################################
#################### Model Class ################################
#################################################################

class DoodleTransformer(LightningModule):
    def __init__(self, num_labels):
        super().__init__()
        
        self.vit = ViTModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)

    def forward(self, input_data):
        outputs = self.vit(pixel_values=input_data)
        # Classify from the final hidden state of the [CLS] token.
        logits = self.classifier(outputs.last_hidden_state[:, 0])
        
        return logits
    
    def common_step(self, batch, batch_idx):
        pixel_values = batch['input_data']
        labels = batch['labels']
        logits = self(pixel_values)

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct/pixel_values.shape[0]

        return loss, accuracy
      
    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        # Log the accuracy on every training step
        # (on_epoch=False, so there is no epoch-level average).
        self.log("accuracy", accuracy, on_step=True, on_epoch=False, prog_bar=True, logger=True)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        
        self.log("validation_loss", loss)
        self.log("validation_accuracy", accuracy)

        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        
        self.log("test_loss", loss)
        self.log("test_accuracy", accuracy)

        return loss

    def configure_optimizers(self):
        # We could make the optimizer fancier by adding a scheduler and excluding some
        # parameters from weight decay (see the sketch after the class), but AdamW
        # out of the box works fine.
        return AdamW(self.parameters(), lr=5e-5)

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader

    def test_dataloader(self):
        return test_dataloader
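
As an aside, the fancier optimizer setup the comment above alludes to might look roughly like this (a sketch; the weight-decay value and the no-decay name list are illustrative assumptions, not code I actually ran):

    def configure_optimizers(self):
        # Exclude biases and LayerNorm weights from weight decay (common practice for transformers).
        no_decay = ["bias", "LayerNorm.weight"]
        grouped_params = [
            {"params": [p for n, p in self.named_parameters()
                        if not any(nd in n for nd in no_decay)],
             "weight_decay": 0.01},  # illustrative value
            {"params": [p for n, p in self.named_parameters()
                        if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0},
        ]
        return AdamW(grouped_params, lr=5e-5)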
    
 
model_path = "path/to/model.ckpt"
model_path_pt = "path/to/state_dict.pt"

# I tried the two lines below together, as well as each one on its own.
model = DoodleTransformer.load_from_checkpoint(model_path, num_labels=340)
model.load_state_dict(torch.load(model_path_pt))
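
One sanity check (just a sketch) is to see whether the keys in the saved state dict actually line up with the model, since load_state_dict reports any mismatches:

# With strict=False, mismatched keys are returned instead of raising an error.
state_dict = torch.load(model_path_pt, map_location="cpu")
load_result = model.load_state_dict(state_dict, strict=False)
print("missing keys:", load_result.missing_keys)
print("unexpected keys:", load_result.unexpected_keys)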



#################################################################
######################## Trainer ################################
#################################################################

early_stop_callback = EarlyStopping(
    monitor='validation_loss',  # must match the key logged in validation_step
    patience=1,
    strict=False,
    verbose=False,
    mode='min'
)

checkpoint_callback = ModelCheckpoint(
    monitor="validation_loss",
    dirpath="./model",
    filename="qd_model-3-{validation_loss:.2f}",  # the metric name here must also match the logged key
    every_n_epochs=1,
    save_weights_only=True  # the .ckpt then stores the state_dict (plus hparams), not the full trainer state
)
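
For what it’s worth, a Lightning .ckpt file is an ordinary pickled dict, so its contents can be inspected directly (a quick sketch; the path is a placeholder):

ckpt = torch.load("path/to/model.ckpt", map_location="cpu")
print(ckpt.keys())                          # with save_weights_only=True this should still include 'state_dict'
print(list(ckpt["state_dict"].keys())[:5])  # first few parameter names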

trainer = Trainer(
    gpus=1,
    callbacks=[early_stop_callback, checkpoint_callback],
    max_epochs=1
)

#################################################################
######################## Training ###############################
#################################################################

# Works as expected
trainer.fit(model)

#################################################################
######################## Testing ################################
#################################################################

# Accuracy is ~0.75 right after training, but ~0.00 when the model is loaded from the saved files (or freshly constructed).
test_results = trainer.test(model=model, dataloaders=test_dataloader)
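
To confirm that it really is the weights that change across the save/load round trip, here is the kind of minimal check I can run on a single fixed batch (a sketch; `batch` stands for any one batch drawn from test_dataloader):

model.eval()
with torch.no_grad():
    logits_before = model(batch['input_data'])

torch.save(model.state_dict(), model_path_pt)
model.load_state_dict(torch.load(model_path_pt))

model.eval()
with torch.no_grad():
    logits_after = model(batch['input_data'])

print(torch.allclose(logits_before, logits_after))  # should print True if the weights round-trip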

Hello, can you try doing this:

test_results = trainer.test(model=model, dataloaders=test_dataloader, ckpt_path="<your_model_checkpoint>")
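
Passing ckpt_path makes the Trainer restore the weights from that checkpoint before running the test loop. If the ModelCheckpoint callback is still attached to the trainer, ckpt_path="best" should also work and resolve to the best checkpoint it tracked:

test_results = trainer.test(model=model, dataloaders=test_dataloader, ckpt_path="best")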