Finetune whisper-tiny in German for TFLite runtime

Hello everyone,

I’ve been trying to fine-tune a customized whisper-tiny model on German, which I want to serve with the TFLite runtime in a Rust application for fast inference on mobile devices.
Loading the stock model from Hugging Face (openai/whisper-tiny at main) and converting it to TFLite works fine. What I haven’t found is a way to fine-tune the model in TensorFlow format so that I end up with the weights as an .h5 file, which I could then load with TFWhisperForConditionalGeneration.from_pretrained() for the TFLite conversion.
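For context, the stock-model conversion that already works for me looks roughly like this (it follows the common recipe of wrapping generate() in a tf.function; the (1, 80, 3000) log-mel shape is what whisper-tiny expects, and max_new_tokens=128 is just an example value):

import tensorflow as tf
from transformers import TFWhisperForConditionalGeneration

model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Wrap generate() so the SavedModel has a fixed-signature serving function
class GenerateModel(tf.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @tf.function(
        input_signature=[tf.TensorSpec((1, 80, 3000), tf.float32, name="input_features")]
    )
    def serving(self, input_features):
        # max_new_tokens=128 is an example value, not a requirement
        return {"sequences": self.model.generate(input_features, max_new_tokens=128)}

generate_model = GenerateModel(model)
tf.saved_model.save(
    generate_model, "whisper_saved", signatures={"serving_default": generate_model.serving}
)

converter = tf.lite.TFLiteConverter.from_saved_model("whisper_saved")
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # native TFLite kernels
    tf.lite.OpsSet.SELECT_TF_OPS,    # TF fallback ops that Whisper still needs
]
tflite_model = converter.convert()
with open("whisper-tiny.tflite", "wb") as f:
    f.write(tflite_model)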

All fine-tuning tutorials I’ve found so far are PyTorch-based, and converting the model from PyTorch via ONNX to TensorFlow format results in a saved_model.pb file, which I can’t load with TFWhisperForConditionalGeneration.from_pretrained() (as far as I can tell, it expects a tf_model.h5 checkpoint rather than a SavedModel).

Hugging Face provides the .h5 format for the whisper-tiny model at the link above, so I assume there must be a way to either fine-tune the model directly in TensorFlow or convert it from PyTorch to TensorFlow format, but I haven’t found a solution so far.
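In code, the route I’m hoping exists would be something like the following untested sketch: fine-tune in PyTorch as the tutorials do, then cross-load the checkpoint into the TF class ("whisper-tiny-de-pt" is a placeholder path, and I’m assuming from_pt=True cross-loading works for Whisper as it does for other architectures):

from transformers import TFWhisperForConditionalGeneration

# Placeholder path to a local PyTorch fine-tuning output (hypothetical)
pt_checkpoint = "whisper-tiny-de-pt"

# Assumption: from_pt=True cross-loads the PyTorch weights into the TF class
tf_model = TFWhisperForConditionalGeneration.from_pretrained(pt_checkpoint, from_pt=True)

# save_pretrained writes tf_model.h5, which from_pretrained() can read again
tf_model.save_pretrained("whisper-tiny-de-tf")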

Does anyone have a hint on how to solve this?

My current PyTorch-to-TF conversion script looks as follows:

import torch
import onnx
import librosa
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import onnx_tf
import numpy as np

# onnx_tf still references the np.bool alias removed in newer numpy versions
np.bool = np.bool_

path = "openai/whisper-tiny"
# only model and processor need to be kept during training
processor = WhisperProcessor.from_pretrained(path)

model = WhisperForConditionalGeneration.from_pretrained(path)
model.eval()  # disable dropout etc. before export

# load an example input
y, sr = librosa.load("audio/en.wav", sr=16_000)
assert sr == 16_000, "Only 16k sr supported"
input_features = processor(y, sampling_rate=sr, return_tensors="pt").input_features
decoder_input_ids = torch.tensor([[50258]])  # <|startoftranscript|>

# sanity check: run one forward pass before exporting
res = model(input_features, None, decoder_input_ids)
# print(res.logits.shape)

torch.onnx.export(model,
                  (input_features, None, decoder_input_ids),
                  'whisper.onnx',
                  input_names=['input_features',
                               'attention_mask',
                               'decoder_input_ids'],
                  output_names=['output'],
                  opset_version=14)

onnx_model = onnx.load('whisper.onnx')

# print("Model Inputs: ", [inp.name for inp in onnx_model.graph.input])

# ONNX -> TensorFlow; export_graph writes a SavedModel directory
tf_model = onnx_tf.backend.prepare(onnx_model)
tf_model.export_graph("whisper.tf")
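
For debugging, the export can presumably still be loaded as a generic SavedModel, just not via from_pretrained() (an assumption here: that onnx_tf registers the usual serving_default signature):

import tensorflow as tf

# Load the onnx_tf export as a plain SavedModel and inspect its signature
loaded = tf.saved_model.load("whisper.tf")
infer = loaded.signatures["serving_default"]
print(infer.structured_input_signature)
print(infer.structured_outputs)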

requirements:

python 3.10.13
tensorflow 2.15.0
torch 2.1.4
transformers 4.45.2

Scripts that HF staff use for their own purposes may be available on GitHub, and in many cases there is no manual. The scripts below are just examples; it’s worth exploring GitHub.

Thanks, I will definitely take a look at it :)
