Including the classification head when saving BERT

Hi, I am using the Trainer with BERT for a classification task. I saved the model with this code:

import transformers

model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
...
trainer.train()
_ = trainer.save_model("history/")

Since the pretrained bert-base-uncased checkpoint has no classification head, this class adds a freshly initialized one, if I remember right. However, when I load the model after training with

config = transformers.AutoConfig.from_pretrained("history")
model = transformers.AutoModelForSequenceClassification.from_config(config)
tokenizer = transformers.AutoTokenizer.from_pretrained("learning/checkpoint-31500")
pipe = transformers.TextClassificationPipeline(
	model=model,
	tokenizer=tokenizer,
	return_all_scores=True)

it seems to create a new head again instead of using the one it trained. This makes it impossible to load and use the trained model.
How can I save the head that Trainer trained for BERT along with the model, so I can load and use it again?

Thanks in advance!
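For context: the Transformers documentation states that from_config only builds the architecture from the configuration and does not load any weights, while from_pretrained on the saved directory restores them, classification head included. The sketch below checks this offline; it assumes a tiny, randomly initialized BertConfig (arbitrary sizes) as a stand-in for the fine-tuned bert-base-uncased model, and uses save_pretrained, which is what Trainer.save_model calls in the default single-process case.

```python
import tempfile

import torch
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    BertConfig,
    BertForSequenceClassification,
)

# Tiny stand-in config so the sketch runs without downloading anything.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=37,
    num_labels=3,
)
torch.manual_seed(0)
trained = BertForSequenceClassification(config)  # pretend this was fine-tuned

with tempfile.TemporaryDirectory() as save_dir:
    # Writes config.json plus all weights, classifier head included.
    trained.save_pretrained(save_dir)

    # from_config: same architecture, but newly initialized weights.
    loaded_config = AutoConfig.from_pretrained(save_dir)
    fresh = AutoModelForSequenceClassification.from_config(loaded_config)

    # from_pretrained: same architecture AND the saved weights.
    restored = AutoModelForSequenceClassification.from_pretrained(save_dir)

head_kept_by_from_config = torch.equal(fresh.classifier.weight, trained.classifier.weight)
head_kept_by_from_pretrained = torch.equal(restored.classifier.weight, trained.classifier.weight)
print("from_config keeps trained head:", head_kept_by_from_config)
print("from_pretrained keeps trained head:", head_kept_by_from_pretrained)
```

Applied to the question, this would mean loading with transformers.AutoModelForSequenceClassification.from_pretrained("history") rather than from_config(config).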

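Since HF models inherit from torch.nn.Module, you can save your model with its custom classifier head using torch.save(model, 'model.pt') and load it with torch.load('model.pt'). On PyTorch 2.6 and later, torch.load defaults to weights_only=True, so loading a whole pickled model requires torch.load('model.pt', weights_only=False). The loaded model can then be passed as the model argument of a Trainer instance.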
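A minimal round-trip sketch of this approach, assuming a tiny randomly initialized BertForSequenceClassification in place of a real fine-tuned model (config sizes and the input IDs are arbitrary). It pickles the whole module, reloads it, and checks that both copies produce the same logits:

```python
import os
import tempfile

import torch
from transformers import BertConfig, BertForSequenceClassification

# Tiny stand-in model so the sketch runs offline.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=37,
    num_labels=3,
)
torch.manual_seed(0)
model = BertForSequenceClassification(config)
model.eval()  # disable dropout so outputs are deterministic

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")
    # Pickle the entire nn.Module: architecture plus weights, head included.
    torch.save(model, path)
    # A full pickled module needs weights_only=False (the default flipped to True in PyTorch 2.6).
    loaded = torch.load(path, weights_only=False)
loaded.eval()

input_ids = torch.tensor([[1, 5, 7, 2]])
with torch.no_grad():
    original_logits = model(input_ids=input_ids).logits
    loaded_logits = loaded(input_ids=input_ids).logits
print(torch.allclose(original_logits, loaded_logits))
```

Because the whole object is pickled, transformers must still be importable (in a compatible version) wherever the file is loaded.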

Now for the second issue: how can you modify Trainer so that it saves your model checkpoints in this format? This might not be the most elegant solution, but I got it working by overriding the _save method of Trainer like this:

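import os
import torch
from transformers import Trainer

class CustomTrainer(Trainer):
    def _save(self, output_dir=None, *args, **kwargs):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        # pickle the whole module (encoder + classification head) as model.pt
        torch.save(self.model, os.path.join(output_dir, 'model.pt'))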

Then create your trainer instance:

# create Trainer instance
trainer = CustomTrainer(
    model=loaded_model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
)

The trainer will then save a file called model.pt in each checkpoint folder.
Note that you will not be able to load this model with ....from_pretrained(), because it is not in the format that method expects, but you can load it with torch.load as described above and pass it to your trainer instance. Hope this helps!
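Trainer names its checkpoint folders checkpoint-&lt;step&gt;, so picking up the newest model.pt means sorting by the numeric step rather than alphabetically ("checkpoint-500" sorts after "checkpoint-1000" as a string). The helper load_latest_checkpoint_model below is hypothetical, a sketch assuming the CustomTrainer above wrote model.pt into each folder; the demo fakes two checkpoints with tiny BERT models that differ only in num_labels so the chosen one is easy to identify.

```python
import os
import re
import tempfile

import torch
from transformers import BertConfig, BertForSequenceClassification


def load_latest_checkpoint_model(output_dir):
    """Hypothetical helper: load model.pt from the highest-step checkpoint-<step> folder."""
    steps = []
    for name in os.listdir(output_dir):
        match = re.fullmatch(r"checkpoint-(\d+)", name)
        if match and os.path.isfile(os.path.join(output_dir, name, "model.pt")):
            steps.append(int(match.group(1)))
    if not steps:
        raise FileNotFoundError(f"no checkpoint-*/model.pt under {output_dir}")
    latest = os.path.join(output_dir, f"checkpoint-{max(steps)}", "model.pt")
    # A full pickled module needs weights_only=False on PyTorch 2.6+.
    return torch.load(latest, weights_only=False)


def tiny_model(num_labels):
    # Small stand-in for a fine-tuned BERT classifier.
    config = BertConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=37,
        num_labels=num_labels,
    )
    return BertForSequenceClassification(config)


with tempfile.TemporaryDirectory() as output_dir:
    # Fake two checkpoints; step 1000 is newer even though it sorts first as text.
    for step, num_labels in [(500, 2), (1000, 3)]:
        ckpt_dir = os.path.join(output_dir, f"checkpoint-{step}")
        os.makedirs(ckpt_dir)
        torch.save(tiny_model(num_labels), os.path.join(ckpt_dir, "model.pt"))
    latest_model = load_latest_checkpoint_model(output_dir)

print(latest_model.config.num_labels)
```

The returned model can then be passed as model=latest_model to a new CustomTrainer instance, as in the constructor call above.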