Replace weights in TFBertModel

I have a multi-label model built from TFAutoModelForSequenceClassification: I took its TFBertMainLayer (bert = transformer_model.layers[0] in the code below) and added a Dropout layer and a Dense layer on top of it.
After compiling and fitting the model, I saved the model weights as an h5 file and the model architecture as a JSON file (using model.to_json() in Keras).

        import tensorflow as tf
        # transformer_model, input_dim and num_labels are defined elsewhere (not shown)
        bert = transformer_model.layers[0]  # the TFBertMainLayer
        input_ids = tf.keras.layers.Input(shape=(input_dim,), name='input_ids', dtype='int32')
        attention_mask = tf.keras.layers.Input(shape=(input_dim,), name='attention_mask', dtype='int32')
        inputs = {'input_ids': input_ids, 'attention_mask': attention_mask}
        # index 1 of the BERT outputs is the pooled [CLS] representation
        pooled_output = bert(input_ids, attention_mask=attention_mask)[1]
        X = tf.keras.layers.Dropout(transformer_model.config.hidden_dropout_prob, name='pooled_output', trainable=True)(pooled_output)
        X = tf.keras.layers.Dense(units=num_labels, activation='sigmoid', name='dense', trainable=True)(X)
        model = tf.keras.Model(inputs=inputs, outputs=X)

I want to visualize the model's attention weights and came across bertviz. However, it doesn't seem to work well with models that are not based on PyTorch.
A possible solution I thought about includes the following steps:

  1. Use the TFBertModel: initialize a TFBertModel and replace the weights of its TFBertMainLayer with the weights of my trained model. I tried something like this:

        from transformers import TFBertModel
        tf_bert_model = TFBertModel.from_pretrained('bert-base-uncased')

But it doesn’t seem to work, and I am not able to replace the weights.

  2. If I can get step 1 to work, I then plan to save tf_bert_model with tf_bert_model.save_pretrained() and load it into the PyTorch class BertModel, which should let me use bertviz.

Any ideas on how I can replace the weights to make step 1 work? Or is there another way around the issue so that I can get bertviz working with my Keras model?

Any help will be greatly appreciated.
Ayala Allon