List index out of range when saving

The following code snippet raises a "list index out of range" error when it reaches the last line, `model.save(save_loc)`.
I’ve searched around, and someone seems to have had a similar issue, which they solved by calling `pretrained_bert_model.bert(inputs)` to get the outputs. However, I looked at the attributes available on my model and couldn’t find anything analogous to call.

I’m using transformers 4.11.3 and TensorFlow 2.5.0.

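    import tensorflow as tf
    import tensorflow_addons as tfa
    from transformers import TFAlbertModel

    max_length = 128                # hypothetical: the real value isn't shown in this snippet
    save_loc = "albert_classifier"  # hypothetical save path

    albert = TFAlbertModel.from_pretrained('albert-base-v2', ignore_mismatched_sizes=True)
    albert.layers[0].trainable = False  # Freeze the base ALBERT transformer so only the classifier head is trained.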

    input_ids = tf.keras.layers.Input(shape=(max_length,),
                                      dtype='int32',
                                      name="input_ids")
    input_attention = tf.keras.layers.Input(shape=(max_length,),
                                            dtype='int32',
                                            name="attention_mask")

    # Index 0 of the model output is the last layer's hidden states, shape (batch, max_length, hidden_size);
    # slicing [:, 0, :] then selects the [CLS] token's vector for each example.
    cls_token = albert([input_ids, input_attention])[0][:, 0, :]

    classifier = tf.keras.layers.Dense(1,
                                       activation='sigmoid',
                                       name="labels"
                                       )(cls_token)

    model = tf.keras.Model([input_ids, input_attention], classifier)

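    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),  # `lr` is a deprecated alias
                  loss=tfa.losses.SigmoidFocalCrossEntropy(alpha=0.75),  # imbalanced data
                  metrics=['Precision', 'Recall',
                           # With a single sigmoid output, set an explicit threshold: with the default
                           # threshold=None, F1Score marks each row's maximum as positive, so every
                           # example would be counted as a positive prediction.
                           tfa.metrics.F1Score(num_classes=1, threshold=0.5)])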

    model.save(save_loc)
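The comments in the snippet describe `[0][:, 0, :]` as taking the last hidden state and then the [CLS] position. A small NumPy sketch of that same indexing, using made-up sizes (batch 2, sequence length 4, hidden size 3; not ALBERT's real dimensions) and a plain array standing in for the model output, shows which values survive the slice:

```python
import numpy as np

# Hypothetical sizes for illustration only (not ALBERT's real dimensions).
batch_size, seq_len, hidden_size = 2, 4, 3

# Stand-in for outputs[0], the last hidden state: shape (batch, seq_len, hidden).
last_hidden_state = np.arange(batch_size * seq_len * hidden_size).reshape(
    batch_size, seq_len, hidden_size
)

# [:, 0, :] keeps every example and every hidden unit, but only sequence
# position 0, which is where the ALBERT tokenizer puts the [CLS] token.
cls_vectors = last_hidden_state[:, 0, :]

print(cls_vectors.shape)  # (2, 3)
print(cls_vectors.tolist())  # [[0, 1, 2], [12, 13, 14]]
```

TensorFlow tensors accept the same slicing syntax, so on the symbolic Keras tensor the result should likewise have shape `(batch, hidden_size)`, which is what the `Dense(1)` classifier consumes.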