Hello everyone,
I’ve been trying to finetune a customized whisper-tiny model for German, which I want to serve with the TFLite runtime in a Rust application for fast inference on mobile devices.
Loading the stock model from Hugging Face (openai/whisper-tiny at main) and converting it to TFLite works fine, but I haven’t found any way to finetune the model in TensorFlow so that I end up with the weights as a .h5 file, which I can then load with TFWhisperForConditionalGeneration.from_pretrained() for the TFLite conversion.
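For context, my working stock conversion follows the usual generate-wrapper pattern, roughly like this (a sketch; the input shape, max_new_tokens, and file names are just what I use, adjust as needed):

import tensorflow as tf
from transformers import TFWhisperForConditionalGeneration

model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# wrap generate() in a tf.function with a fixed input signature
class GenerateModel(tf.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @tf.function(input_signature=[tf.TensorSpec((1, 80, 3000), tf.float32, name="input_features")])
    def serving(self, input_features):
        outputs = self.model.generate(input_features,
                                      max_new_tokens=128,
                                      return_dict_in_generate=True)
        return {"sequences": outputs["sequences"]}

generate_model = GenerateModel(model=model)
tf.saved_model.save(generate_model, "whisper_saved",
                    signatures={"serving_default": generate_model.serving})

converter = tf.lite.TFLiteConverter.from_saved_model("whisper_saved")
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # native TFLite ops
    tf.lite.OpsSet.SELECT_TF_OPS,    # fall back to TF ops that generate() needs
]
tflite_model = converter.convert()
with open("whisper-tiny.tflite", "wb") as f:
    f.write(tflite_model)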
All finetuning tutorials I’ve found so far are PyTorch-based, and converting the model from PyTorch via ONNX to TensorFlow has resulted in a saved_model.pb file, which I can’t load with TFWhisperForConditionalGeneration.from_pretrained().
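Concretely, this is the load that fails ("whisper.tf" is the SavedModel directory my conversion script below produces):

from transformers import TFWhisperForConditionalGeneration

# "whisper.tf" is a plain SavedModel directory (saved_model.pb + variables/),
# not a transformers checkpoint with config.json + tf_model.h5, so this
# raises an error instead of loading the weights
model = TFWhisperForConditionalGeneration.from_pretrained("whisper.tf")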
Hugging Face provides the .h5 format for the whisper-tiny model at the link above, so I assume there must be a way to either finetune the model in TensorFlow or convert it from PyTorch to TensorFlow format, but I haven’t found a solution so far.
Does anyone have a hint on how to solve this?
My current PyTorch-to-TF conversion script looks as follows:
import torch
import onnx
import librosa
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import onnx_tf
import numpy as np
np.bool = np.bool_  # shim: onnx-tf still references the deprecated np.bool alias
path = "openai/whisper-tiny"
# only model and processor need to be kept during training
processor = WhisperProcessor.from_pretrained(path)
model = WhisperForConditionalGeneration.from_pretrained(path)
# load an example input (16 kHz audio)
y, sr = librosa.load("audio/en.wav", sr=16_000)
assert sr==16_000, "Only 16k sr supported"
input_features = processor(y, sampling_rate=sr, return_tensors="pt").input_features
decoder_input_ids = torch.tensor([[50258]])  # <|startoftranscript|> token
# sanity-check forward pass before exporting
res = model(input_features, None, decoder_input_ids)
# print(res.logits.shape)
torch.onnx.export(model,
                  (input_features, None, decoder_input_ids),
                  'whisper.onnx',
                  input_names=['input_features',
                               'attention_mask',
                               'decoder_input_ids'],
                  output_names=['output'],
                  opset_version=14)
onnx_model = onnx.load('whisper.onnx')
# print("Model Inputs: ", [inp.name for inp in onnx_model.graph.input])
# convert the ONNX graph to a TensorFlow SavedModel directory
tf_model = onnx_tf.backend.prepare(onnx_model)
tf_model.export_graph("whisper.tf")
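The export itself goes through, and the result should load as a plain SavedModel (untested sketch below), it just isn’t a transformers checkpoint:

import tensorflow as tf

# the onnx-tf export is a regular SavedModel, so this should load,
# but it's a frozen graph rather than something from_pretrained() accepts
loaded = tf.saved_model.load("whisper.tf")
print(list(loaded.signatures.keys()))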
requirements:
python 3.10.13
tensorflow 2.15.0
torch 2.1.4
transformers 4.45.2