Dear Hugging Face community,
I’ve recently run into some difficulties with model inference. Specifically, I want to use only numpy and onnxruntime (without relying on transformers or torch) to run a VisionEncoderDecoder model exported by optimum (Spedon/texify-quantized-onnx · Hugging Face).
Based on my research so far, I’ve managed to put together a prototype (using the two ONNX files from the link above), but its output differs significantly from the correct result obtained with optimum.pipeline
(the first few token IDs are fine, but then it starts to go wrong).
==> Result from optimum.pipeline
"The potential $V_i$ of cell $\\mathcal{C}_i$ centred at position $\\mathbf{r}_i$ is related to the surface charge densities $\\sigma_j$ of cells $\\mathcal{C}_j$ $j\\in[1,N]$ through the superposition principle as: $$V_i\\,=\\,\\sum_{j=0}^{N}\\,\\frac{\\sigma_j}{4\\pi\\epsilon_0}\\,\\int_{\\mathcal{C}_j}\\frac{1}{\\|\\mathbf{r}_i-\\mathbf{r}^\\prime\\|}\\,\\mathrm{d}^2\\mathbf{r}^\\prime\\,=\\,\\sum_{j=0}^{N}\\,Q_{ij}\\,\\sigma_j,$$ where the integral over the surface of cell $\\mathcal{C}_j$ only depends on $\\mathcal{C}_j$ shape and on the relative position of the target point $\\mathbf{r}_i$ with respect to $\\mathcal{C}_j$ location, as $\\sigma_j$ is assumed constant over the whole surface of cell $\\mathcal{C}_j$."
==> Result from my implementation
"The potential $\\rho_{\\rm H}$ is defined by\n\n$$\\rho_{\\rm H}=\\frac{1}{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right"
Below is my implementation:
import onnxruntime as ort
from transformers import AutoImageProcessor
from tokenizers import Tokenizer
import numpy as np
from PIL import Image
from numpy.typing import NDArray


class TxfOnnx:
    def __init__(self) -> None:
        self.encoder_session = ort.InferenceSession(
            "./qmodel/encoder_model_quantized.onnx"
        )
        self.decoder_session = ort.InferenceSession(
            "./qmodel/decoder_model_merged_quantized.onnx"
        )
        self.processor = AutoImageProcessor.from_pretrained("./qmodel")
        self.tokenizer: Tokenizer = Tokenizer.from_file("./qmodel/tokenizer.json")

    def preprocess(self, image: Image.Image) -> NDArray[np.float32]:
        return self.processor(image, return_tensors="np").pixel_values

    def encode(self, pixel_values: NDArray[np.float32]) -> NDArray[np.float32]:
        encoder_result = self.encoder_session.run(
            None, {"pixel_values": pixel_values}
        )
        hidden_state: NDArray[np.float32] = encoder_result[0]
        return hidden_state

    def decode_without_cache(
        self, hidden_state: NDArray[np.float32]
    ) -> tuple[np.intp, dict[str, NDArray[np.float32]]]:
        # First decoding step: feed empty past key/values and take the no-cache branch.
        default_kv = np.zeros((1, 16, 0, 64), dtype=np.float32)
        decoder_inputs = {}
        decoder_inputs["encoder_hidden_states"] = hidden_state
        decoder_inputs["input_ids"] = np.array([[0]], dtype=np.int64)  # start token
        decoder_inputs["use_cache_branch"] = np.array([False])
        for i in range(8):
            decoder_inputs[f"past_key_values.{i}.encoder.key"] = default_kv
            decoder_inputs[f"past_key_values.{i}.encoder.value"] = default_kv
            decoder_inputs[f"past_key_values.{i}.decoder.key"] = default_kv
            decoder_inputs[f"past_key_values.{i}.decoder.value"] = default_kv
        decoder_result = self.decoder_session.run(None, decoder_inputs)
        logits: NDArray[np.float32] = decoder_result[0]
        next_id = np.argmax(logits[0, -1, :])
        # Return the greedy token together with all outputs keyed by output name
        # (logits plus the present.* key/value tensors).
        return next_id, dict(
            zip(
                [node.name for node in self.decoder_session.get_outputs()],
                decoder_result,
            )
        )

    def decode_with_cache(
        self,
        hidden_state: NDArray[np.float32],
        input_ids: list[np.intp],
        last_decoder_result: dict[str, NDArray[np.float32]],
    ) -> tuple[np.intp, dict[str, NDArray[np.float32]]]:
        # Subsequent steps: feed only the last token and pass the present.* tensors
        # from the previous step back in as past_key_values.*.
        decoder_inputs = {}
        decoder_inputs["encoder_hidden_states"] = hidden_state
        decoder_inputs["input_ids"] = np.array([input_ids], dtype=np.int64)
        decoder_inputs["use_cache_branch"] = np.array([True])
        for i in range(8):
            decoder_inputs[f"past_key_values.{i}.encoder.key"] = last_decoder_result[
                f"present.{i}.encoder.key"
            ]
            decoder_inputs[f"past_key_values.{i}.encoder.value"] = last_decoder_result[
                f"present.{i}.encoder.value"
            ]
            decoder_inputs[f"past_key_values.{i}.decoder.key"] = last_decoder_result[
                f"present.{i}.decoder.key"
            ]
            decoder_inputs[f"past_key_values.{i}.decoder.value"] = last_decoder_result[
                f"present.{i}.decoder.value"
            ]
        decoder_result = self.decoder_session.run(None, decoder_inputs)
        logits: NDArray[np.float32] = decoder_result[0]
        next_id = np.argmax(logits[0, -1, :])
        return next_id, dict(
            zip(
                [node.name for node in self.decoder_session.get_outputs()],
                decoder_result,
            )
        )

    def inference(self, image: Image.Image) -> str:
        MAX_LENGTH = 384
        pixel_values = self.preprocess(image)
        hidden_state = self.encode(pixel_values)
        result_idx = [np.intp(0)]  # start token id
        last_decoder_result = None
        while len(result_idx) < MAX_LENGTH:
            if last_decoder_result is None:
                next_id, last_decoder_result = self.decode_without_cache(hidden_state)
            else:
                next_id, last_decoder_result = self.decode_with_cache(
                    hidden_state, result_idx[-1:], last_decoder_result
                )
            result_idx.append(next_id)
            if next_id == 2:  # eos token
                break
        return self.tokenizer.decode(result_idx)


image = Image.open("./latex.png").convert("RGB")
inst = TxfOnnx()
print(inst.inference(image))
Later I found that if I decode in a more “traditional” way, re-running the decoder on the full sequence at every step instead of using the cache, I do get the correct result; the downside is that performance is poor. (A small per-step comparison of the two implementations is sketched after the code below.)
import onnxruntime as ort
from transformers import AutoImageProcessor
from tokenizers import Tokenizer
import numpy as np
from PIL import Image
from numpy.typing import NDArray


class TxfOnnxSlow:
    def __init__(self) -> None:
        self.encoder_session = ort.InferenceSession(
            "./qmodel/encoder_model_quantized.onnx"
        )
        self.decoder_session = ort.InferenceSession(
            "./qmodel/decoder_model_merged_quantized.onnx"
        )
        self.processor = AutoImageProcessor.from_pretrained("./qmodel")
        self.tokenizer: Tokenizer = Tokenizer.from_file("./qmodel/tokenizer.json")

    def preprocess(self, image: Image.Image) -> NDArray[np.float32]:
        return self.processor(image, return_tensors="np").pixel_values

    def encode(self, pixel_values: NDArray[np.float32]) -> NDArray[np.float32]:
        encoder_result = self.encoder_session.run(
            None, {"pixel_values": pixel_values}
        )
        hidden_state: NDArray[np.float32] = encoder_result[0]
        return hidden_state

    def decode_slow(
        self,
        hidden_state: NDArray[np.float32],
        input_ids: list[np.intp],
    ) -> np.intp:
        # Every step re-feeds the full sequence generated so far and always takes
        # the no-cache branch, so the present.* outputs are simply discarded.
        default_kv = np.zeros((1, 16, 0, 64), dtype=np.float32)
        decoder_inputs = {}
        decoder_inputs["encoder_hidden_states"] = hidden_state
        decoder_inputs["input_ids"] = np.array([input_ids], dtype=np.int64)
        decoder_inputs["use_cache_branch"] = np.array([False])
        for i in range(8):
            decoder_inputs[f"past_key_values.{i}.encoder.key"] = default_kv
            decoder_inputs[f"past_key_values.{i}.encoder.value"] = default_kv
            decoder_inputs[f"past_key_values.{i}.decoder.key"] = default_kv
            decoder_inputs[f"past_key_values.{i}.decoder.value"] = default_kv
        decoder_result = self.decoder_session.run(None, decoder_inputs)
        logits: NDArray[np.float32] = decoder_result[0]
        next_id = np.argmax(logits[0, -1, :])
        return next_id

    def inference(self, image: Image.Image) -> str:
        MAX_LENGTH = 384
        pixel_values = self.preprocess(image)
        hidden_state = self.encode(pixel_values)
        result_idx = [np.intp(0)]  # start token id
        while len(result_idx) < MAX_LENGTH:
            next_id = self.decode_slow(hidden_state, result_idx)
            result_idx.append(next_id)
            if next_id == 2:  # eos token
                break
        return self.tokenizer.decode(result_idx)


image = Image.open("./latex.png").convert("RGB")
inst = TxfOnnxSlow()
print(inst.inference(image))
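To narrow down where the cached path diverges, a per-step comparison of the two implementations above can help (a minimal sketch assuming both classes are in scope and ./latex.png is the same test image; the 50-step limit is arbitrary):

import numpy as np
from PIL import Image

image = Image.open("./latex.png").convert("RGB")
fast = TxfOnnx()
slow = TxfOnnxSlow()

# The encoder output is identical for both paths, so compute it once.
pixel_values = fast.preprocess(image)
hidden_state = fast.encode(pixel_values)

# Cached path: same greedy loop as TxfOnnx.inference, but keeping the raw ids.
fast_ids = [np.intp(0)]
last = None
for _ in range(50):
    if last is None:
        next_id, last = fast.decode_without_cache(hidden_state)
    else:
        next_id, last = fast.decode_with_cache(hidden_state, fast_ids[-1:], last)
    fast_ids.append(next_id)
    if next_id == 2:
        break

# Uncached path: same greedy loop as TxfOnnxSlow.inference.
slow_ids = [np.intp(0)]
for _ in range(50):
    next_id = slow.decode_slow(hidden_state, slow_ids)
    slow_ids.append(next_id)
    if next_id == 2:
        break

# Report the first position where the two greedy token streams disagree.
for pos, (a, b) in enumerate(zip(fast_ids, slow_ids)):
    if a != b:
        print(f"first divergence at step {pos}: cached={a}, uncached={b}")
        break
else:
    print("no divergence within the compared prefix")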
May I ask which step went wrong, or what the correct inference flow should be when using the merged decoder ONNX model?