Hello,
I’m training Google’s ViT model with an extra classification layer on a doodle dataset, using Hugging Face Transformers and PyTorch Lightning. After 5 hours of training, test accuracy rose from 0.0 to 0.75. However, when I saved the model and later loaded it again, test accuracy fell back to 0. This has happened 2–3 times. Could someone help me figure out what I am doing wrong?
I am using Google’s Quick, Draw! dataset, downloaded from the Kaggle competition “Quick, Draw! Doodle Recognition Challenge”.
The relevant parts of my code are below:
import torch
from torch import nn
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from transformers import ViTModel, ViTConfig
from transformers import ViTForImageClassification, AdamW
# Hub identifier of the pretrained ViT backbone loaded by DoodleTransformer.
model_name = "google/vit-base-patch16-224-in21k"
#################################################################
#################### Model Class ################################
#################################################################
class DoodleTransformer(LightningModule):
    """Pretrained ViT backbone plus a linear classification head.

    The backbone is ``ViTModel.from_pretrained(model_name)``. The head maps the
    backbone's final hidden state at token position 0 to ``num_labels`` logits.

    Args:
        num_labels: Number of output classes (output size of the linear head).
    """

    def __init__(self, num_labels):
        super().__init__()
        # Record ``num_labels`` in ``self.hparams`` so it is written into every
        # checkpoint. ``load_from_checkpoint(path)`` can then rebuild a head of
        # the right size by itself. Passing ``num_labels=...`` explicitly
        # still works and overrides the stored value.
        self.save_hyperparameters()
        self.vit = ViTModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)

    def forward(self, input_data):
        """Return class logits for a batch of images.

        Args:
            input_data: Pixel tensor passed to the ViT as ``pixel_values``
                (presumably (batch, channels, height, width); confirm against
                the dataloader).

        Returns:
            One row of ``num_labels`` logits per example.
        """
        outputs = self.vit(pixel_values=input_data)
        # The final hidden state of the first token is used as the image embedding.
        first_token = outputs.last_hidden_state[:, 0]
        return self.classifier(first_token)

    def common_step(self, batch, batch_idx):
        """Compute the loss and accuracy of one batch.

        Shared by the training, validation and test steps so that all three
        measure exactly the same thing.

        Args:
            batch: Mapping with ``'input_data'`` (pixels) and ``'labels'``
                (integer class ids) entries.
            batch_idx: Index of the batch. Unused; kept for Lightning's hook
                signature.

        Returns:
            ``(loss, accuracy)``: the mean cross-entropy loss tensor and the
            fraction of correct argmax predictions as a Python float.
        """
        pixel_values = batch['input_data']
        labels = batch['labels']
        logits = self(pixel_values)
        loss = nn.functional.cross_entropy(logits, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct / pixel_values.shape[0]
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        """Compute and log the training metrics for one batch.

        Returns:
            The loss tensor that Lightning backpropagates.
        """
        loss, accuracy = self.common_step(batch, batch_idx)
        # Logged per step only (on_epoch=False), so no epoch average is recorded.
        self.log("accuracy", accuracy, on_step=True, on_epoch=False, prog_bar=True, logger=True)
        self.log("train_loss", loss, on_step=True, on_epoch=False, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``validation_loss`` / ``validation_accuracy`` for one batch."""
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("validation_loss", loss)
        self.log("validation_accuracy", accuracy)
        return loss

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` / ``test_accuracy`` for one batch."""
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("test_loss", loss)
        self.log("test_accuracy", accuracy)
        return loss

    def configure_optimizers(self):
        """Optimise all parameters with AdamW at a fixed learning rate of 5e-5."""
        # A scheduler and per-parameter weight-decay groups could be added, but
        # plain AdamW is sufficient here.
        # NOTE(review): ``transformers.AdamW`` is deprecated in recent
        # transformers releases. ``torch.optim.AdamW`` is the replacement; its
        # default ``eps``/``weight_decay`` differ, so pass them explicitly if
        # switching.
        return AdamW(self.parameters(), lr=5e-5)

    def train_dataloader(self):
        """Return the script-level ``train_dataloader`` (defined outside this class)."""
        return train_dataloader

    def val_dataloader(self):
        """Return the script-level ``val_dataloader`` (defined outside this class)."""
        return val_dataloader

    def test_dataloader(self):
        """Return the script-level ``test_dataloader`` (defined outside this class)."""
        return test_dataloader
# Artifacts written by the earlier training run.
model_path = "path/to/model.ckpt"
model_path_pt = "path/to/state_dict.pt"
# Either line below restores trained weights on its own. Running both
# overwrites the checkpoint's weights with those from the .pt file.
# NOTE(review): if reloaded accuracy is at chance level (~1/num_labels, which
# prints as 0.00), verify that the label -> index mapping used by the test
# dataloader is identical to the one used during training. A mapping built
# from an unordered collection (e.g. a ``set`` of class names) can change
# between Python processes. The dataset code is not shown here, so confirm.
model = DoodleTransformer.load_from_checkpoint(model_path, num_labels=340)
# map_location="cpu" lets a state dict saved from GPU tensors load on any
# machine; load_state_dict then copies the values into the model's parameters.
model.load_state_dict(torch.load(model_path_pt, map_location="cpu"))
#################################################################
######################## Trainer ################################
#################################################################
# Stop when the validation loss stops improving. ``monitor`` must match a
# name the model passes to ``self.log``; the model logs "validation_loss".
# With ``strict=False`` a wrong key does not raise, so early stopping would
# simply never trigger.
early_stop_callback = EarlyStopping(
    monitor='validation_loss',
    patience=1,
    strict=False,
    verbose=False,
    mode='min'
)
# Keep the checkpoint with the lowest validation loss. Fields in the
# ``filename`` template are filled from logged metrics, so they must use the
# same logged name as ``monitor``.
checkpoint_callback = ModelCheckpoint(
    monitor="validation_loss",
    dirpath="./model",
    filename="qd_model-3-{validation_loss:.2f}",
    every_n_epochs=1,
    save_weights_only=True
)
trainer = Trainer(
    # NOTE(review): ``gpus`` is deprecated in newer Lightning versions
    # (use accelerator="gpu", devices=1 there).
    gpus=1,
    callbacks=[early_stop_callback, checkpoint_callback],
    max_epochs=1
)
#################################################################
######################## Training ###############################
#################################################################
# Training runs as expected. The model supplies its own train/val data
# through its ``train_dataloader`` / ``val_dataloader`` hooks.
trainer.fit(model=model)
#################################################################
######################## Testing ################################
#################################################################
# Evaluate on the held-out test set. Observed: ~0.75 accuracy right after
# training, but ~0.00 for a model reloaded from disk or freshly constructed.
test_results = trainer.test(model, dataloaders=test_dataloader)