Since further pre-training BERT might suffer from catastrophic forgetting, I am wondering: is it possible to use gradual unfreezing somehow with the Trainer module? I could do it with PyTorch and a classical training loop, but I wanted to check whether it is already implemented in the Trainer or TrainingArguments class.
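As far as I know there is no built-in gradual-unfreezing flag, but you can roll your own with a `TrainerCallback`: in `on_epoch_begin`, toggle `requires_grad` on the encoder layers according to a schedule that unfreezes from the top (closest to the classification head) downwards. Here is a minimal, library-free sketch of such a schedule; the function name and defaults are my own, not a transformers API:

```python
def unfrozen_layers(epoch, num_layers, layers_per_epoch=1):
    """Indices of encoder layers that should be trainable at the start
    of `epoch`, unfreezing top-down one group of layers per epoch.
    (Hypothetical helper, not part of transformers.)"""
    n = min(num_layers, (epoch + 1) * layers_per_epoch)
    return list(range(num_layers - n, num_layers))
```

Inside the callback you would then do something like `for i, layer in enumerate(model.bert.encoder.layer): layer.requires_grad_(i in unfrozen_layers(epoch, 12))`, and keep the classification head trainable throughout.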
I think you can add skip connections to mitigate catastrophic forgetting in, for example, BERT. It should be possible by subclassing and modifying the main class, but I don't know how to do that:
class BertForMultilabelSequenceClassification(BertForSequenceClassification):
    ...
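To illustrate the skip-connection idea without pulling in transformers, here is a toy classification head with a residual connection around the dense transform. All names here are hypothetical; the real BertForSequenceClassification head is structured differently, so treat this as a sketch of the technique, not the actual subclass:

```python
import torch
import torch.nn as nn

class HeadWithSkip(nn.Module):
    """Toy classification head with a residual (skip) connection.
    Hypothetical sketch, not the real BERT head."""
    def __init__(self, hidden=768, num_labels=4):
        super().__init__()
        self.dense = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, num_labels)

    def forward(self, pooled):
        h = torch.tanh(self.dense(pooled))
        # Skip connection: the pretrained representation passes through
        # unchanged and is added to the transformed one.
        h = h + pooled
        return self.out(h)
```

In an actual subclass you would override `forward` (or swap out `self.classifier`) after calling the parent `__init__`.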