Hello all, I have created a flavour of BERT with new embeddings by reusing existing BERT modules. My workflow is as follows:
- Modify run_mlm.py to add the new embeddings in the forward pass.
- Run MLM training and save the pre-trained model.
- Modify BertModel for classification (approx. 40 classes) by adapting the embedding layer and the forward pass.
- Load the pre-trained model from step 2, fine-tune it on the classification task, and save it.
- Load the fine-tuned model for classification prediction.
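A minimal sketch of how the last three steps might round-trip a custom model, assuming plain PyTorch. `ToyBertWithExtraEmbedding`, `save` and `load` are hypothetical stand-ins (a mean-pooled toy, not the real BertModel or the Hugging Face `from_pretrained` path). What it illustrates: reload into the same custom class, load the state dict with `strict=True` so a missing or renamed embedding raises an error instead of silently keeping its random initialization, and call `model.eval()` before predicting so dropout behaves as it did during validation.

```python
import torch
import torch.nn as nn


class ToyBertWithExtraEmbedding(nn.Module):
    """Hypothetical stand-in for the modified BertModel: word embeddings plus
    one extra embedding table whose output is added in forward()."""

    def __init__(self, vocab_size=100, extra_vocab_size=10, hidden=16, num_labels=40):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden)
        self.extra_embeddings = nn.Embedding(extra_vocab_size, hidden)  # the "new" embedding
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden, num_labels)

    def forward(self, input_ids, extra_ids):
        # Sum the standard and the new embeddings (assumed combination rule).
        h = self.word_embeddings(input_ids) + self.extra_embeddings(extra_ids)
        # Mean-pool over the sequence instead of running a full encoder (simplification).
        pooled = self.dropout(h.mean(dim=1))
        return self.classifier(pooled)


def save(model: nn.Module, path: str) -> None:
    # Save only the parameters; the class definition must be importable at load time.
    torch.save(model.state_dict(), path)


def load(path: str, **kwargs) -> ToyBertWithExtraEmbedding:
    # Rebuild the *same* custom class, not a stock BERT class, then load strictly
    # so any missing or renamed parameter raises instead of staying randomly initialized.
    model = ToyBertWithExtraEmbedding(**kwargs)
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state, strict=True)
    model.eval()  # disable dropout, matching the validation loop
    return model
```

If the reload instead goes through a stock class's `from_pretrained`, its log output is worth checking for a warning that some weights were newly initialized.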
The problem is that during fine-tuning (step 5) I see exceptional performance on the validation set, but when I load the saved model to make predictions (step 6) it performs extremely poorly and consistently predicts only a few classes. I have made sure that the same code runs in the validation loop as in prediction. I can also see the saved new embedding layers when loading the checkpoint with PyTorch's torch.load. An off-the-shelf BERT model works well on the dataset, which suggests that the issue is not with the DataLoader.
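Seeing the embedding tensors in the checkpoint does not by itself show that they ended up in the model used for prediction. A minimal diagnostic sketch, assuming plain PyTorch and that both models are available in one session (for example the in-memory model at the end of fine-tuning and the freshly reloaded one). `diff_state_dicts` and `predicted_class_counts` are hypothetical helpers; the second assumes each batch is a tuple of positional inputs and that the model returns a logits tensor (a Hugging Face classification model would instead need `model(**batch).logits`).

```python
from collections import Counter

import torch
import torch.nn as nn


def diff_state_dicts(a: nn.Module, b: nn.Module):
    """Compare parameters and buffers of two models, e.g. the model at the end
    of fine-tuning (a) and the model reloaded for prediction (b)."""
    sa, sb = a.state_dict(), b.state_dict()
    missing = sorted(set(sa) - set(sb))     # in the trained model, absent after reload
    unexpected = sorted(set(sb) - set(sa))  # present only after reload
    # torch.equal returns False for differing shapes or values; move to CPU first.
    changed = sorted(
        k for k in set(sa) & set(sb) if not torch.equal(sa[k].cpu(), sb[k].cpu())
    )
    return missing, unexpected, changed


@torch.no_grad()
def predicted_class_counts(model: nn.Module, batches):
    """Histogram of argmax predictions; collapse onto a few classes shows up here."""
    model.eval()  # note: switches the model to eval mode as a side effect
    counts = Counter()
    for inputs in batches:
        logits = model(*inputs)
        counts.update(logits.argmax(dim=-1).tolist())
    return counts
```

If `missing` or `unexpected` is non-empty, or `changed` lists the new embedding or classifier weights, the reloaded model is not the one that was validated. Running `predicted_class_counts` on the validation set with both models shows whether the collapse comes from the weights or from the prediction inputs.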
I am stuck, as I am not sure how best to debug this next; any suggestions would be most appreciated!
Many thanks in advance for your help!