Very poor model performance post-training

Hello all - I have created a flavour of the BERT model with new embeddings by reusing the existing BERT modules. My workflow is as follows:

  1. Modify run_mlm.py to add the new embeddings in the forward pass.
  2. Run MLM training and save the pre-trained model.
  3. Modify BertModel for classification (approx. 40 classes) by adapting the embedding layer and the forward pass.
  4. Load the pre-trained model from step 2, fine-tune on the classification task, and save.
  5. Load the fine-tuned model for classification prediction (the save/load flow is sketched below).
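
For reference, steps 4 and 5 look roughly like this (a minimal sketch: the class name `MyBertForClassification`, the new embedding layer, and the paths are simplified placeholders for my actual code):

```python
import torch
from transformers import BertConfig, BertModel, BertPreTrainedModel

# Placeholder stand-in for the modified model from step 3; the real class adds
# the new embedding layer into the forward pass (details simplified here).
class MyBertForClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        # Placeholder for the new embeddings; the size is illustrative only.
        self.new_embeddings = torch.nn.Embedding(100, config.hidden_size)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, extra_ids=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask, return_dict=True)
        pooled = outputs.pooler_output
        if extra_ids is not None:
            # Simplified: the new embeddings are folded into the representation here.
            pooled = pooled + self.new_embeddings(extra_ids)
        return self.classifier(pooled)

# Step 4: initialise from the MLM checkpoint saved in step 2, fine-tune, save.
config = BertConfig.from_pretrained("path/to/mlm_checkpoint", num_labels=40)
model = MyBertForClassification.from_pretrained("path/to/mlm_checkpoint", config=config)
# ... fine-tuning loop on the classification data ...
model.save_pretrained("path/to/finetuned_checkpoint")

# Step 5: reload the fine-tuned model for prediction.
model = MyBertForClassification.from_pretrained("path/to/finetuned_checkpoint")
model.eval()
```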

The problem is that in step 4 I see excellent performance on the validation set, but when I load the model to make predictions (step 5) it performs extremely poorly and consistently predicts only a few classes. I have made sure that the same code runs in the validation loop as in prediction, and I can also see the saved new embedding layers when I load the checkpoint with torch.load (the check is sketched below). An off-the-shelf BERT model works well on the dataset, which suggests the issue is not with the DataLoader.
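
This is roughly the check I ran to confirm the new embedding layers are present in the saved checkpoint (a minimal sketch; the path, weight file name, and layer name are placeholders for whatever step 4 actually writes out):

```python
import torch

# Inspect the saved fine-tuned checkpoint directly with torch.load.
state_dict = torch.load("path/to/finetuned_checkpoint/pytorch_model.bin", map_location="cpu")

# Print every parameter that belongs to the new embedding layer, with its shape.
for name, tensor in state_dict.items():
    if "new_embeddings" in name:
        print(name, tuple(tensor.shape))
```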

I am stuck as I am not sure how best to debug this next - any suggestions would be most appreciated!
Many thanks for your help in advance!