Heads up: Transformer layers + Functional API -> missing trainable weights

tl;dr: If you have used the functional API with transformer layers, it might be worth running

for var in model.trainable_weights:
    print(var.name)

to see if all your weights are there.
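
A slightly fuller check (a sketch, assuming a plain Keras model; it catches the case where a layer still holds variables that the model itself does not track):

model_var_names = {v.name for v in model.trainable_weights}
layer_var_names = {v.name
                   for layer in model.layers
                   for v in layer.trainable_weights}

# Any name printed here is a variable the model will not train.
for name in sorted(layer_var_names - model_var_names):
    print("not tracked by the model:", name)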


I am wondering if this is a huge bug/flaw in Keras: using transformer layers (and possibly any custom layers) with the functional API results in weights missing from `trainable_variables`. Those weights are not in `non_trainable_variables` either.

But if those weights aren’t in `trainable_variables`, they are essentially frozen, since only the variables in that list receive gradient updates.
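
To make that concrete, this is the shape of a standard TF2 training step (a sketch; model.fit does the equivalent internally):

import tensorflow as tf

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
    # Gradients are taken only w.r.t. model.trainable_variables;
    # a weight missing from that list is never updated.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss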

The bug can be seen in this Colab gist:

https://colab.research.google.com/gist/Santosh-Gupta/766a27c1500a330cba6f479805dad27d/missingtrainablevarsinference.ipynb

This gist uses the transformers library to create the models, so it’s easy to see the bug. For an in-depth look, the Colab gist below creates all the transformer layers from scratch (based on HF’s code):

https://colab.research.google.com/gist/Santosh-Gupta/ec9f7c8a189a8d99e73d96fbe728aef8/model_weight_debug_scratch_public_inference.ipynb
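
The core pattern in both gists is roughly this (a minimal sketch using the transformers library; the pooling detail and names are illustrative, and tuple indexing of the output depends on the transformers version):

import tensorflow as tf
from transformers import TFBertModel

bert = TFBertModel.from_pretrained("bert-base-uncased")

# Functional API: call the pretrained layer on a symbolic Input.
input_ids = tf.keras.Input(shape=(128,), dtype=tf.int32)
sequence_output = bert(input_ids)[0]  # last hidden states
logits = tf.keras.layers.Dense(2)(sequence_output[:, 0, :])

model = tf.keras.Model(inputs=input_ids, outputs=logits)

# On affected versions, this prints far fewer variables than BERT has.
for var in model.trainable_weights:
    print(var.name)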

As you can see in the notebooks, a workaround is to build the models with Keras model subclassing instead; subclassing results in all the weights appearing in `trainable_variables`.
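
A minimal sketch of the subclassed version (the classifier head here is illustrative, not taken from the gists):

import tensorflow as tf
from transformers import TFBertModel

class BertClassifier(tf.keras.Model):
    def __init__(self, num_classes):
        super().__init__()
        # Sublayers assigned as attributes are tracked by Keras.
        self.bert = TFBertModel.from_pretrained("bert-base-uncased")
        self.head = tf.keras.layers.Dense(num_classes)

    def call(self, input_ids, training=False):
        outputs = self.bert(input_ids, training=training)
        # Index 0 is the sequence output in older transformers versions.
        cls_embedding = outputs[0][:, 0, :]
        return self.head(cls_embedding)

model = BertClassifier(num_classes=2)
# Run one batch so the Dense head creates its variables too.
_ = model(tf.constant([[101, 2023, 102]]))
print(len(model.trainable_weights))  # all BERT weights + the head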

However, I have models that I already trained with the functional API. I’ve been looking at this for about a month, and as far as I can tell, any Keras model that uses custom sublayers with the functional API is prone to this. So I am wondering whether any of the other models I trained with custom layers and the functional API may have had compromised training.

I put up a GitHub issue 24 days ago, but I can’t tell if this is something being worked on.

Edit: I can only put two links in a post, but you can search the TensorFlow issues for “Keras layer weights/sublayers getting deleted when creating a model with them”.