What is the right way of developing my own model based on a pretrained transformer?

I would like to experiment with the code for WhisperForAudioClassification. I believe I can improve it, and to prove that, I need to write my own code.

But of course, I don’t have the resources to train the model from scratch. Looking at this repo, this is how to load the classification model for fine-tuning:

model = AutoModelForAudioClassification.from_pretrained(
        model_args.model_name_or_path,  # e.g. "openai/whisper-medium"
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)

...

The key part is AutoModelForAudioClassification, which loads a WhisperForAudioClassification model (both its code and its weights). That is fine, but if I fine-tune the model this way, I’ll be fine-tuning the existing implementation, whereas I want to introduce my own code.

Basically, what I would like to learn is how to load the “openai/whisper-medium” weights into my own class. Of course, the weights and the class need to be compatible. To expand on that, the WhisperForAudioClassification class adds a head to the encoder part of Whisper, and I want to write a new, different head while keeping the same encoder. Naturally, the pretrained weights from “openai/whisper-medium” will only populate the encoder, not the head.
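One way to do this, assuming torch and a recent transformers release, is to mirror how WhisperForAudioClassification is structured: subclass WhisperPreTrainedModel, keep a WhisperEncoder under the attribute name `encoder`, and add your own head. `from_pretrained` then copies the checkpoint's encoder weights into `self.encoder`, skips the decoder weights, and randomly initializes the head. `MyWhisperClassifier`, its mean pooling and its head layout are hypothetical. So that the sketch runs offline, a tiny randomly initialized WhisperModel saved to a temporary directory stands in for "openai/whisper-medium".

```python
import tempfile

import torch
from torch import nn
from transformers import WhisperConfig, WhisperModel, WhisperPreTrainedModel
from transformers.models.whisper.modeling_whisper import WhisperEncoder


class MyWhisperClassifier(WhisperPreTrainedModel):
    """Hypothetical custom model: pretrained Whisper encoder + a new head."""

    def __init__(self, config):
        super().__init__(config)
        # Named "encoder" as in WhisperModel, so the checkpoint's encoder weights
        # map onto it (from_pretrained handles the base-model prefix "model.").
        self.encoder = WhisperEncoder(config)
        # Your own head: absent from the checkpoint, so it is randomly initialized.
        self.head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.GELU(),
            nn.Linear(config.d_model, config.num_labels),
        )
        self.post_init()

    def forward(self, input_features):
        hidden = self.encoder(input_features).last_hidden_state  # (batch, frames, d_model)
        return self.head(hidden.mean(dim=1))  # mean-pool over time, then classify


# Offline stand-in for "openai/whisper-medium": a tiny, randomly initialized model.
config = WhisperConfig(
    d_model=64, encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=4, decoder_attention_heads=4,
    encoder_ffn_dim=128, decoder_ffn_dim=128,
)
base = WhisperModel(config)

with tempfile.TemporaryDirectory() as checkpoint_dir:
    base.save_pretrained(checkpoint_dir)
    # Encoder weights are copied in, decoder weights are skipped, head is new.
    model = MyWhisperClassifier.from_pretrained(checkpoint_dir, num_labels=4)

model.eval()
with torch.no_grad():
    # 3000 mel frames = 30 s of 16 kHz audio, the length Whisper's encoder expects
    logits = model(torch.randn(2, config.num_mel_bins, 3000))
print(logits.shape)
```

With the real checkpoint you would call `MyWhisperClassifier.from_pretrained("openai/whisper-medium", num_labels=...)` instead; transformers should then warn that the decoder weights were not used and that the head weights were newly initialized, which is the expected outcome here.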

Can someone please help me learn how to populate my own code with pretrained weights?

If you want to load the pretrained Whisper weights, use them as a backbone, and add your own classification layers on top, you could use the AutoModel class:

from transformers import AutoModel

model = AutoModel.from_pretrained("model_name_or_path")

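Then prepare your inputs beforehand and pass them to the model to get the desired output. Note that Whisper takes input_features (log-Mel spectrograms from its feature extractor) rather than input_ids, and that for Whisper, AutoModel returns the full encoder-decoder WhisperModel, whose forward pass also requires decoder_input_ids; calling model.get_encoder() and passing the features to the encoder alone avoids that.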

outputs = model(**inputs)

Fetch the desired output element, e.g. last_hidden_state, and apply your classification head to the extracted embeddings.
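A minimal sketch of this approach as a plain PyTorch module, assuming torch, numpy and a recent transformers release are installed. It keeps only the Whisper encoder (via `get_encoder()`), turns raw audio into `input_features` with WhisperFeatureExtractor, mean-pools `last_hidden_state` over time, and feeds the result to a small head. `EncoderWithHead`, the pooling choice and the head layout are hypothetical; a tiny randomly initialized WhisperModel stands in for the pretrained checkpoint so the example runs offline.

```python
import numpy as np
import torch
from torch import nn
from transformers import WhisperConfig, WhisperFeatureExtractor, WhisperModel


class EncoderWithHead(nn.Module):
    """Hypothetical wrapper: Whisper encoder as backbone + a custom head."""

    def __init__(self, encoder, num_labels):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(encoder.config.d_model, num_labels),
        )

    def forward(self, input_features):
        # last_hidden_state: (batch, frames, d_model) embeddings from the encoder
        embeddings = self.encoder(input_features).last_hidden_state
        return self.head(embeddings.mean(dim=1))  # mean-pool over time, then classify


# Default feature extractor settings: 80 mel bins, 16 kHz audio, padded to 30 s.
feature_extractor = WhisperFeatureExtractor()
audio = [np.zeros(16000, dtype=np.float32), np.random.randn(8000).astype(np.float32)]
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")

# Offline stand-in for WhisperModel.from_pretrained("openai/whisper-medium").get_encoder()
config = WhisperConfig(
    d_model=64, encoder_layers=2, decoder_layers=1,
    encoder_attention_heads=4, decoder_attention_heads=4,
    encoder_ffn_dim=128, decoder_ffn_dim=128,
)
model = EncoderWithHead(WhisperModel(config).get_encoder(), num_labels=3)
model.eval()

with torch.no_grad():
    logits = model(inputs.input_features)
print(inputs.input_features.shape, logits.shape)
```

With the real backbone, freezing the encoder (setting `requires_grad = False` on `model.encoder.parameters()`) and training only the head is a cheap first experiment before fine-tuning everything.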