hey @gabeorlanski, sorry for the slow reply.
if i understand correctly, the problem you're facing is that after the first epoch, the encoder continues to be initialised from the `facebook/bart-base` checkpoint - is that right?
as you suspect, i think this line might be the problem:

```python
self.encoder = AutoModel.from_pretrained(self.config.encoder_model)
```
because `config.encoder_model` would always point to whatever value you defined in the config. i wonder whether the problem can be solved by replacing `AutoModel.from_pretrained` with a dedicated model class, like this:

```python
self.encoder = BartModel(config)
```
this is closer to what you see in the source code for `BertForSequenceClassification` and (i think) ensures the model is loaded from the `config.json` associated with each epoch's checkpoint.
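
in case it's useful, here's a rough sketch of that pattern - the class name, classification head, and checkpoint paths are all placeholders, so adapt them to your setup. note i've named the submodule `self.model` and set `base_model_prefix = "model"`, which is how BART's own head classes map the bare `facebook/bart-base` weights onto the submodule:

```python
from torch import nn
from transformers import BartConfig, BartModel, PreTrainedModel


class MyBartClassifier(PreTrainedModel):  # placeholder name
    config_class = BartConfig
    # matches the prefix BART's head classes use, so the bare
    # facebook/bart-base weights get mapped onto self.model below
    base_model_prefix = "model"

    def __init__(self, config: BartConfig):
        super().__init__(config)
        # built from the config object, so the weights come from whatever
        # checkpoint from_pretrained is called with - no hard-coded name
        self.model = BartModel(config)
        self.classifier = nn.Linear(config.d_model, config.num_labels)

    def forward(self, input_ids, attention_mask=None):
        hidden = self.model(input_ids, attention_mask=attention_mask).last_hidden_state
        # toy head: classify from the first decoder position
        return self.classifier(hidden[:, 0])


# first run: start from the pretrained checkpoint
model = MyBartClassifier.from_pretrained("facebook/bart-base", num_labels=2)

# later: save and reload from your own checkpoint, not bart-base
model.save_pretrained("my-checkpoint")  # placeholder path
model = MyBartClassifier.from_pretrained("my-checkpoint")
```

the second `from_pretrained` call reads the `config.json` and weights saved in `my-checkpoint`, so resuming after an epoch should pick up your fine-tuned encoder instead of re-initialising from `facebook/bart-base`.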