PyTorch BERT model not converging

I’m currently doing research on transfer learning, and I’m using bert-base-cased as a pretrained baseline, wrapped in a PyTorch model with a dropout layer and a linear layer. I largely follow the BERT authors’ recommendations and train with the AdamW optimizer and a learning-rate scheduler.
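For reference, a wrapper of the kind described (BERT, then dropout, then a linear layer) might look roughly like the minimal sketch below. The class name `BertClassifier`, the use of the pooled `[CLS]` output, and the 0.1 dropout rate are assumptions, not the poster’s code. So that the example runs offline, it builds a tiny, randomly initialised `BertModel` from a `BertConfig`. In practice you would pass in `BertModel.from_pretrained("bert-base-cased")`.

```python
import torch
from torch import nn
from transformers import BertConfig, BertModel


class BertClassifier(nn.Module):
    """Hypothetical wrapper: BERT encoder -> dropout -> linear classification head."""

    def __init__(self, bert: BertModel, num_labels: int, dropout: float = 0.1):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Pooled [CLS] representation, shape (batch, hidden_size).
        pooled = outputs.pooler_output
        return self.classifier(self.dropout(pooled))


# Tiny random config so the example runs without a download (assumption for the demo).
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64)
model = BertClassifier(BertModel(config), num_labels=2)

input_ids = torch.randint(0, 100, (4, 10))
attention_mask = torch.ones_like(input_ids)
logits = model(input_ids, attention_mask)  # shape (4, 2)
```

This sketch returns the logits tensor directly. The training loop in the post unpacks `logits, *_ = self.model(...)`, which is only correct if the real wrapper returns a tuple. Applied to a bare tensor, that line would split it along the batch dimension.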

I can train on one task without problems, using the BERT-recommended hyperparameters: LR=2e-5, batch size=32, epochs=2. I use a cross-entropy loss with class weighting to deal with label imbalance.
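As a point of comparison, the optimiser setup described here can be sketched with `torch.optim.AdamW` plus a linear warmup/decay schedule built from `torch.optim.lr_scheduler.LambdaLR`. The `transformers` library offers a schedule of the same shape as `get_linear_schedule_with_warmup`. The stand-in `nn.Linear` model, the dataset size, the 10% warmup fraction and the weight decay are all assumptions for illustration.

```python
import torch
from torch import nn

# Hyperparameters from the post; num_examples, warmup fraction and weight decay are hypothetical.
lr, batch_size, epochs = 2e-5, 32, 2
num_examples = 3200
steps_per_epoch = (num_examples + batch_size - 1) // batch_size  # 100
total_steps = steps_per_epoch * epochs                            # 200
warmup_steps = int(0.1 * total_steps)                             # 20

model = nn.Linear(768, 2)  # stand-in for the BERT wrapper
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)


def linear_warmup_decay(step: int) -> float:
    """Multiplier on lr: ramps 0 -> 1 over warmup, then decays linearly to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, linear_warmup_decay)

lrs = []
for step in range(total_steps):
    optimizer.step()   # in real training, preceded by loss.backward()
    scheduler.step()   # once per batch, after optimizer.step()
    lrs.append(scheduler.get_last_lr()[0])
```

The schedule is tied to `total_steps`, so when training moves on to a second task, a fresh optimizer and scheduler sized for that task’s batches is the usual choice.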

The problem comes when I save this model and then reload its state as the starting point for further training on a different classification task: training never finishes a single epoch. What’s also strange is that GPU utilisation sits at 100% and never changes, which doesn’t happen with the baseline models.
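One way to tell whether the second task is truly stuck or just much slower (for example, because it has far more batches or much longer tokenised inputs) is to log progress with per-batch timings. The sketch below is generic. `run_with_progress`, `minibatches` and `train_step` are hypothetical placeholders for your own data and the body of the training loop.

```python
import time
from typing import Callable, Iterable, List


def run_with_progress(minibatches: Iterable, train_step: Callable,
                      log_every: int = 10) -> List[float]:
    """Run train_step on each batch, printing timings every log_every batches."""
    try:
        total = len(minibatches)  # lists, ranges and DataLoaders support len()
    except TypeError:
        total = None  # plain generators do not
    durations = []
    start = time.perf_counter()
    for i, batch in enumerate(minibatches, start=1):
        t0 = time.perf_counter()
        train_step(batch)
        # On GPU, call torch.cuda.synchronize() here first: CUDA kernels run
        # asynchronously, so without it the per-batch timing can be misleading.
        durations.append(time.perf_counter() - t0)
        if i % log_every == 0 or i == total:
            elapsed = time.perf_counter() - start
            print(f"batch {i}/{total if total is not None else '?'} "
                  f"last={durations[-1]:.3f}s elapsed={elapsed:.1f}s")
    return durations


# Demo with a no-op step; replace with the real data and training step.
durations = run_with_progress(range(25), lambda batch: None, log_every=10)
```

If the counter advances but the projected epoch time is huge, the issue is throughput (number of batches, sequence length) rather than the reload. If it never gets past the first batch, something inside the step itself is blocking.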

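```python
from typing import Dict, Iterable

import torch


def train(self, minibatches: Iterable) -> Dict:
    # Put dropout layers into training mode.
    self.model = self.model.train()

    metrics = {"n_correct": 0,
               "losses": []}

    for batch in minibatches:
        texts, targets = batch
        # Assumption: self.device is the model's device (this line was cut off in the post).
        targets = targets.to(self.device)
        encoded_input = BERTPreprocessor.encode(texts, self.tokenizer)

        # Move every tensor produced by the tokenizer onto the same device.
        for tensor in encoded_input:
            encoded_input[tensor] = encoded_input[tensor].to(self.device)

        logits, *_ = self.model(**{
            "input_ids": encoded_input["input_ids"],
            "attention_mask": encoded_input["attention_mask"],
        })

        loss = self.loss_fn(logits, targets)

        # Assumption: self.optimizer / self.scheduler are the AdamW optimizer and
        # scheduler mentioned above; these steps were not in the posted snippet.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        _, preds = torch.max(logits, dim=1)
        metrics["n_correct"] += torch.sum(preds == targets).item()
        metrics["losses"].append(loss.item())

    return metrics
```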

I’ve included my training loop above and I’m happy to provide any further information, but as I’m fairly new to the field, I’m at a loss as to how to address this.

To add: the weight I pass for each class in the loss function’s weight parameter is num_minority_class / class_n, computed for the negative and positive classes (binary classification).
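The weighting scheme above can be sketched as follows: each class gets weight (size of the minority class) / (size of that class), so the minority class has weight 1.0 and the majority class less. The label counts here are made up for illustration. `nn.CrossEntropyLoss` accepts these weights through its `weight` argument.

```python
from collections import Counter

import torch
from torch import nn

# Hypothetical binary labels: 0 = negative (majority), 1 = positive (minority).
labels = [0] * 900 + [1] * 100

counts = Counter(labels)
num_minority = min(counts.values())
# Weight for class c is num_minority_class / count(c), as described in the post.
weights = torch.tensor([num_minority / counts[c] for c in range(2)], dtype=torch.float)

loss_fn = nn.CrossEntropyLoss(weight=weights)
logits = torch.randn(8, 2)
targets = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1])
loss = loss_fn(logits, targets)  # scalar, weighted mean over the batch
```

For the second task, the weights would need recomputing from that task’s label counts, and the weight tensor has to be on the same device as the logits.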


[I am not an expert, but I have saved and reloaded BERT models.]

What commands are you using to save and reload your model? Are you saving just the BERT weights, or your custom dropout and linear layers too?
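For comparison, one common pattern is to save the `state_dict` of the whole wrapper (encoder plus head; dropout has no parameters, but the linear layer does), reload it into a model with the same shapes, and only then swap in a new head for a task with a different number of labels. The `Wrapper` class below is hypothetical and uses an `nn.Linear` stand-in for BERT so it runs offline. The save/load calls are the same for a real BERT wrapper, and the checkpoint path is illustrative.

```python
import os
import tempfile

import torch
from torch import nn


class Wrapper(nn.Module):
    """Hypothetical stand-in: `encoder` plays the role of BERT."""

    def __init__(self, num_labels: int):
        super().__init__()
        self.encoder = nn.Linear(16, 8)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(8, num_labels)

    def forward(self, x):
        return self.classifier(self.dropout(torch.relu(self.encoder(x))))


task1 = Wrapper(num_labels=2)
path = os.path.join(tempfile.mkdtemp(), "task1.pt")  # hypothetical checkpoint path

# Save everything (encoder + head), not just the encoder.
torch.save(task1.state_dict(), path)

# Reload for task 2: rebuild with task-1 shapes, load, then swap the head.
task2 = Wrapper(num_labels=2)
task2.load_state_dict(torch.load(path, map_location="cpu"))
task2.classifier = nn.Linear(8, 3)  # fresh head for a (hypothetical) 3-class task

# Build the optimizer AFTER swapping the head so it sees the new parameters.
optimizer = torch.optim.AdamW(task2.parameters(), lr=2e-5)
```

With `map_location="cpu"`, the model still has to be moved to the GPU with `.to(device)` before training.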

Do you get any error messages on reloading? Do you get any error messages when you try to run your second classification task (or does it just go on in an infinite loop)?
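On the reloading side, `load_state_dict` is strict by default and raises an error on mismatched keys. With `strict=False` it instead returns the missing and unexpected keys, which is a quick way to check what was actually restored. The small `nn.Sequential` models here are hypothetical stand-ins: the checkpoint contains only the encoder, not the classifier head.

```python
import torch
from torch import nn

# Hypothetical: checkpoint saved from the encoder only (no classifier head).
encoder_only = nn.Sequential(nn.Linear(16, 8))
full_model = nn.Sequential(nn.Linear(16, 8), nn.Linear(8, 2))

result = full_model.load_state_dict(encoder_only.state_dict(), strict=False)
print("missing:", result.missing_keys)        # parameters left at their initial values
print("unexpected:", result.unexpected_keys)  # checkpoint entries that were ignored
```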

Have you tried running your second task immediately after your first task (without doing a save and reload)? Does it work?

[Silly questions] How do you know your first task has “worked”? Are you using BertModel, and have you considered using BertForSequenceClassification? When you train, are you updating the BERT weights, or have you frozen them?
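On the last question: freezing the encoder means setting `requires_grad = False` on its parameters, so no gradients are computed for them and only the head is trained. (Alternatively, `transformers` provides `BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)`, which bundles the dropout and linear head.) The `nn.ModuleDict` below is a hypothetical stand-in for the wrapper, with `"bert"` playing the role of the encoder.

```python
import torch
from torch import nn

# Hypothetical wrapper: "bert" stands in for the BERT encoder.
model = nn.ModuleDict({"bert": nn.Linear(16, 8), "classifier": nn.Linear(8, 2)})

# Freeze the encoder: its weights get no gradients and are not updated.
for param in model["bert"].parameters():
    param.requires_grad = False

# Only hand the trainable parameters to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=2e-5)

n_trainable = sum(p.numel() for p in trainable)
n_total = sum(p.numel() for p in model.parameters())
print(f"trainable {n_trainable} / {n_total}")
```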