[Keras] Fine-Tune Vision Transformer Model?

I’m looking for a Keras-style approach to freezing and unfreezing a Vision Transformer model. For example, with a Hugging Face vision model, I can do the following:

from transformers import TFSegformerForImageClassification as tfseg

tf_huggingface_module = tfseg.from_pretrained(
    'nvidia/mit-b0'
)
tf_huggingface_module.trainable = False

tf_huggingface_module.layers
[<transformers.models....TFSegformerMainLayer at 0x7f2ad0>,
 <keras.layers.core.Dense at 0x7f2aa6ca3650>]

Now, what if I want to freeze only a few layers, from the bottom to the middle, and unfreeze the rest? How should I do that with the Hugging Face API? FYI, in the Keras API we can do something like this:

from tensorflow.keras import layers

# Unfreeze the top 20 layers, keeping BatchNorm layers frozen
for layer in model.layers[-20:]:
    if not isinstance(layer, layers.BatchNormalization):
        layer.trainable = True
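For completeness, here is that pattern in a fully runnable form. It uses `tf.keras.applications.ResNet50` as a stand-in backbone (with random weights, just to keep the example self-contained); one caveat worth noting is that freezing per layer, rather than setting `model.trainable = False` on the whole model, keeps the model's own `trainable` flag `True` so the unfrozen layers actually contribute to `trainable_weights`:

```python
import tensorflow as tf
from tensorflow.keras import layers

# Stand-in backbone; weights=None skips downloading pretrained weights.
base_model = tf.keras.applications.ResNet50(weights=None, include_top=False)

# Freeze everything below the top 20 layers...
for layer in base_model.layers[:-20]:
    layer.trainable = False

# ...and in the top 20, keep BatchNorm frozen so its moving
# statistics are not updated during fine-tuning.
for layer in base_model.layers[-20:]:
    layer.trainable = not isinstance(layer, layers.BatchNormalization)
```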

I have found this blog post, but I need a more precise pointer.

Hi,

Refer to this answer: How to freeze GPT-2 model layers with Tensorflow/Keras? · Issue #18282 · huggingface/transformers · GitHub

Hello, it looks like I need to look at the model-building code to get the proper attribute name, for example `model.transformer.wte.trainable`. Is there any documentation regarding this, for example for the vision models in this case?

Hi,

from the link above:

To reach the layer you want to freeze, the best way is to navigate the code of the original model and find its attribute name.

So it’s advised to just check the implementation of the model. The implementation starts here. As you can see, TFViTModel contains a single attribute, “vit”. This leads to the TFViTMainLayer class. There we have the attributes “embeddings”, “encoder”, “layernorm” and “pooler”. So to freeze the weights of the embeddings for instance, you can do:

from transformers import TFViTModel

model = TFViTModel.from_pretrained("google/vit-base-patch16-224")

model.vit.embeddings.trainable = False

Thank you for the detailed answer. Yes, I understood that from the previous comments, but I was expecting some kind of textual documentation on this. IMO, it’s a little weird to have to go to the source code for the relevant attribute name.