List index out of range when saving

The following code snippet gives me a `list index out of range` error when it gets to the last line, where the model is saved :pensive:
I've searched around, and it seems someone had a similar-ish issue that they solved by calling `pretrained_bert_model.bert(inputs)` to get the outputs. But when I looked at the attributes available on my model, I couldn't find anything similar to call :persevere:

I'm using transformers 4.11.3 and TensorFlow 2.5.0.

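    import tensorflow as tf
    import tensorflow_addons as tfa
    from transformers import TFAlbertModel

    max_length = 128  # placeholder: the real sequence length isn't shown in this snippet

    albert = TFAlbertModel.from_pretrained('albert-base-v2', ignore_mismatched_sizes=True)
    albert.layers[0].trainable = False  # Freeze the base ALBERT transformer so we only train the classifier.

    input_ids = tf.keras.layers.Input(shape=(max_length,), dtype='int32', name='input_ids')
    input_attention = tf.keras.layers.Input(shape=(max_length,), dtype='int32', name='attention_mask')

    # The item at index 0 of the output is the hidden state from the last layer,
    # shape (batch, max_length, hidden_size); selecting [:, 0, :] gives us the CLS token.
    cls_token = albert([input_ids, input_attention])[0][:, 0, :]

    # Sigmoid activation assumed: the original Dense(...) arguments were cut off.
    classifier = tf.keras.layers.Dense(1, activation='sigmoid')(cls_token)

    model = tf.keras.Model([input_ids, input_attention], classifier)

    model.compile(optimizer='adam',  # placeholder: the original optimizer argument was cut off
                  loss=tfa.losses.SigmoidFocalCrossEntropy(alpha=0.75),  # imbalanced data
                  # threshold=0.5: without it, a single-output F1Score counts every prediction as positive
                  metrics=['Precision', 'Recall', tfa.metrics.F1Score(num_classes=1, threshold=0.5)])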
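For anyone unsure what the `[:, 0, :]` slice does: the first output of the model is the last hidden state, shaped (batch, sequence length, hidden size), and position 0 of every sequence is the `[CLS]` token the ALBERT tokenizer prepends. The toy sketch below uses plain nested lists with made-up numbers (not real ALBERT outputs) to show which values the slice keeps.

```python
# Toy stand-in for the last hidden state: shape (batch=2, seq_len=3, hidden=2).
# Position 0 of each sequence plays the role of the [CLS] token's vector.
last_hidden_state = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # example 0
    [[1.1, 1.2], [1.3, 1.4], [1.5, 1.6]],  # example 1
]

def select_cls(hidden_state):
    """Pure-Python equivalent of hidden_state[:, 0, :]: keep token 0 of every sequence."""
    return [sequence[0] for sequence in hidden_state]

cls_vectors = select_cls(last_hidden_state)
print(cls_vectors)  # [[0.1, 0.2], [1.1, 1.2]]
```

The result has shape (batch, hidden), which is what the `Dense(1, ...)` head receives.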
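On the `alpha=0.75` choice for imbalanced data: as far as I understand the formula behind `tfa.losses.SigmoidFocalCrossEntropy` (defaults `alpha=0.25`, `gamma=2.0`, inputs treated as probabilities), positives are weighted by `alpha` and negatives by `1 - alpha`, and the `(1 - p_t)**gamma` factor shrinks the loss on examples the model already gets right. Below is a minimal pure-Python sketch of that formula for a single binary label. It is an illustration under those assumptions, not tfa's actual implementation.

```python
import math

def sigmoid_focal_loss(y_true, y_prob, alpha=0.25, gamma=2.0, eps=1e-7):
    """Per-example focal loss for one binary label (sketch, not tfa's code).
    Assumes y_prob is a sigmoid output, i.e. a probability, not a logit."""
    y_prob = min(max(y_prob, eps), 1.0 - eps)  # clip to avoid log(0)
    # p_t: probability the model assigns to the true class
    p_t = y_true * y_prob + (1 - y_true) * (1 - y_prob)
    # alpha weights positive labels, (1 - alpha) weights negative labels
    alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)
    # (1 - p_t)^gamma down-weights confidently correct predictions
    modulating_factor = (1.0 - p_t) ** gamma
    cross_entropy = -math.log(p_t)
    return alpha_factor * modulating_factor * cross_entropy

# Same confidence in the correct class (p_t = 0.9) for a positive and a negative:
pos = sigmoid_focal_loss(1, 0.9, alpha=0.75)
neg = sigmoid_focal_loss(0, 0.1, alpha=0.75)
print(pos / neg)  # ≈ 3.0, i.e. 0.75 / 0.25
```

So with `alpha=0.75`, a mistake on the (rarer) positive class costs about three times as much as an equally confident mistake on a negative.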