Question about the inference flow for an Optimum-exported merged decoder ONNX model

Dear Hugging Face community,

I’ve recently run into some difficulties with model inference. Specifically, I want to use numpy and onnxruntime (without relying on transformers and torch) to run inference on a VisionEncoderDecoder model exported by Optimum (Spedon/texify-quantized-onnx on Hugging Face).

Based on my research so far, I’ve managed to build a prototype (using the two ONNX files from the link above), but its output differs significantly from the correct result produced by optimum.pipeline: the first few token IDs match, but after that it starts to go wrong.

==> Result from optimum.pipeline
"The potential $V_i$ of cell $\\mathcal{C}_i$ centred at position $\\mathbf{r}_i$ is related to the surface charge densities $\\sigma_j$ of cells $\\mathcal{C}_j$ $j\\in[1,N]$ through the superposition principle as: $$V_i\\,=\\,\\sum_{j=0}^{N}\\,\\frac{\\sigma_j}{4\\pi\\epsilon_0}\\,\\int_{\\mathcal{C}_j}\\frac{1}{\\|\\mathbf{r}_i-\\mathbf{r}^\\prime\\|}\\,\\mathrm{d}^2\\mathbf{r}^\\prime\\,=\\,\\sum_{j=0}^{N}\\,Q_{ij}\\,\\sigma_j,$$ where the integral over the surface of cell $\\mathcal{C}_j$ only depends on $\\mathcal{C}_j$ shape and on the relative position of the target point $\\mathbf{r}_i$ with respect to $\\mathcal{C}_j$ location, as $\\sigma_j$ is assumed constant over the whole surface of cell $\\mathcal{C}_j$."

==> Result from my impl
"The potential $\\rho_{\\rm H}$ is defined by\n\n$$\\rho_{\\rm H}=\\frac{1}{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right)^{2}\\left(\\frac{\\rho_{\\rm H}}{\\rho_{\\rm H}}\\right"

Below is my implementation:

import onnxruntime as ort
from transformers import AutoImageProcessor
from tokenizers import Tokenizer
import numpy as np
from PIL import Image
from numpy.typing import NDArray


class TxfOnnx:
    def __init__(self) -> None:
        self.encoder_session = ort.InferenceSession(
            "./qmodel/encoder_model_quantized.onnx"
        )
        self.decoder_session = ort.InferenceSession(
            "./qmodel/decoder_model_merged_quantized.onnx"
        )
        self.processor = AutoImageProcessor.from_pretrained("./qmodel")
        self.tokenizer: Tokenizer = Tokenizer.from_file("./qmodel/tokenizer.json")

    def preprocess(self, image: Image.Image) -> NDArray[np.float32]:
        return self.processor(image, return_tensors="np").pixel_values

    def encode(self, pixel_values: NDArray[np.float32]) -> NDArray[np.float32]:
        encoder_result: NDArray[np.float32] = self.encoder_session.run(
            None, {"pixel_values": pixel_values}
        )
        hidden_state: NDArray[np.float32] = encoder_result[0]
        return hidden_state

    def decode_without_cache(
        self, hidden_state: NDArray[np.float32]
    ) -> tuple[np.intp, dict[str, NDArray[np.float32]]]:
        default_kv = np.zeros((1, 16, 0, 64), dtype=np.float32)

        decoder_inputs = {}
        decoder_inputs["encoder_hidden_states"] = hidden_state
        decoder_inputs["input_ids"] = np.array([[0]], dtype=np.int64)  # start token
        decoder_inputs["use_cache_branch"] = np.array([False])

        for i in range(8):
            decoder_inputs[f"past_key_values.{i}.encoder.key"] = default_kv
            decoder_inputs[f"past_key_values.{i}.encoder.value"] = default_kv
            decoder_inputs[f"past_key_values.{i}.decoder.key"] = default_kv
            decoder_inputs[f"past_key_values.{i}.decoder.value"] = default_kv

        decoder_result: NDArray[np.float32] = self.decoder_session.run(
            None, decoder_inputs
        )
        logits: NDArray[np.float32] = decoder_result[0]
        next_id = np.argmax(logits[0, -1, :])

        return next_id, dict(
            zip(
                [node.name for node in self.decoder_session.get_outputs()],
                decoder_result,
            )
        )

    def decode_with_cache(
        self,
        hidden_state: NDArray[np.float32],
        input_ids: list[np.intp],
        last_decoder_result: dict[str, NDArray[np.float32]],
    ) -> tuple[np.intp, dict[str, NDArray[np.float32]]]:
        decoder_inputs = {}
        decoder_inputs["encoder_hidden_states"] = hidden_state
        decoder_inputs["input_ids"] = np.array([input_ids], dtype=np.int64)
        decoder_inputs["use_cache_branch"] = np.array([True])

        for i in range(8):
            decoder_inputs[f"past_key_values.{i}.encoder.key"] = last_decoder_result[
                f"present.{i}.encoder.key"
            ]
            decoder_inputs[f"past_key_values.{i}.encoder.value"] = last_decoder_result[
                f"present.{i}.encoder.value"
            ]
            decoder_inputs[f"past_key_values.{i}.decoder.key"] = last_decoder_result[
                f"present.{i}.decoder.key"
            ]
            decoder_inputs[f"past_key_values.{i}.decoder.value"] = last_decoder_result[
                f"present.{i}.decoder.value"
            ]

        decoder_result: NDArray[np.float32] = self.decoder_session.run(
            None, decoder_inputs
        )
        logits: NDArray[np.float32] = decoder_result[0]
        next_id = np.argmax(logits[0, -1, :])

        return next_id, dict(
            zip(
                [node.name for node in self.decoder_session.get_outputs()],
                decoder_result,
            )
        )

    def inference(self, image: Image.Image) -> str:
        MAX_LENGTH = 384

        pixel_values = self.preprocess(image)
        hidden_state = self.encode(pixel_values)

        result_idx = [np.intp(0)]  # start token id
        last_decoder_result = None

        while len(result_idx) < MAX_LENGTH:
            if last_decoder_result is None:
                next_id, last_decoder_result = self.decode_without_cache(hidden_state)
            else:
                next_id, last_decoder_result = self.decode_with_cache(
                    hidden_state, result_idx[-1:], last_decoder_result
                )
            result_idx.append(next_id)
            if next_id == 2:  # eos token
                break

        return self.tokenizer.decode(result_idx)


image = Image.open("./latex.png").convert("RGB")
inst = TxfOnnx()
print(inst.inference(image))

Later I found that a more “traditional” approach, which re-runs the decoder over the full token sequence at every step, gives the correct result. However, it can’t use the KV cache, so it is slow.

import onnxruntime as ort
from transformers import AutoImageProcessor
from tokenizers import Tokenizer
import numpy as np
from PIL import Image
from numpy.typing import NDArray


class TxfOnnxSlow:
    def __init__(self) -> None:
        self.encoder_session = ort.InferenceSession(
            "./qmodel/encoder_model_quantized.onnx"
        )
        self.decoder_session = ort.InferenceSession(
            "./qmodel/decoder_model_merged_quantized.onnx"
        )
        self.processor = AutoImageProcessor.from_pretrained("./qmodel")
        self.tokenizer: Tokenizer = Tokenizer.from_file("./qmodel/tokenizer.json")

    def preprocess(self, image: Image.Image) -> NDArray[np.float32]:
        return self.processor(image, return_tensors="np").pixel_values

    def encode(self, pixel_values: NDArray[np.float32]) -> NDArray[np.float32]:
        encoder_result: NDArray[np.float32] = self.encoder_session.run(
            None, {"pixel_values": pixel_values}
        )
        hidden_state: NDArray[np.float32] = encoder_result[0]
        return hidden_state

    def decode_slow(
        self,
        hidden_state: NDArray[np.float32],
        input_ids: list[np.intp],
    ) -> np.intp:
        default_kv = np.zeros((1, 16, 0, 64), dtype=np.float32)

        decoder_inputs = {}
        decoder_inputs["encoder_hidden_states"] = hidden_state
        decoder_inputs["input_ids"] = np.array([input_ids], dtype=np.int64)
        decoder_inputs["use_cache_branch"] = np.array([False])

        for i in range(8):
            decoder_inputs[f"past_key_values.{i}.encoder.key"] = default_kv
            decoder_inputs[f"past_key_values.{i}.encoder.value"] = default_kv
            decoder_inputs[f"past_key_values.{i}.decoder.key"] = default_kv
            decoder_inputs[f"past_key_values.{i}.decoder.value"] = default_kv

        decoder_result: NDArray[np.float32] = self.decoder_session.run(
            None, decoder_inputs
        )
        logits: NDArray[np.float32] = decoder_result[0]
        next_id = np.argmax(logits[0, -1, :])

        return next_id

    def inference(self, image: Image.Image) -> str:
        MAX_LENGTH = 384

        pixel_values = self.preprocess(image)
        hidden_state = self.encode(pixel_values)

        result_idx = [np.intp(0)]  # start token id

        while len(result_idx) < MAX_LENGTH:
            next_id = self.decode_slow(hidden_state, result_idx)
            result_idx.append(next_id)
            if next_id == 2:
                break

        return self.tokenizer.decode(result_idx)


image = Image.open("./latex.png").convert("RGB")
inst = TxfOnnxSlow()
print(inst.inference(image))
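To put rough numbers on why the cache-free loop above is slow: at step t it feeds all t tokens generated so far, so the decoder recomputes every earlier position each time. The helper below is a simple count of decoder token positions processed, assuming one start token and ignoring encoder cost and per-position attention cost; `positions_processed` is a hypothetical name for illustration.

```python
def positions_processed(num_generated: int, use_cache: bool) -> int:
    """Count decoder token positions computed to produce num_generated tokens
    after a single start token (encoder cost is ignored)."""
    if use_cache:
        # With a KV cache, each step feeds only the newest token: 1 position per step.
        return num_generated
    # Without a cache, step t re-runs the decoder over all t tokens seen so far,
    # so the total is 1 + 2 + ... + num_generated.
    return sum(range(1, num_generated + 1))


print(positions_processed(100, use_cache=False))  # 5050
print(positions_processed(100, use_cache=True))   # 100
```

The cache-free cost grows quadratically with output length, which matters for long LaTeX outputs like the one above.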

Which step went wrong, or what is the correct inference flow when using the merged decoder ONNX model?

Based on my understanding, after obtaining the hidden state from the encoder, the next step is the first forward pass of the decoder, with use_cache_branch set to False. The hidden state and the start token ID are fed in to get the past_kv and the next token ID.

Once we have the past_kv, use_cache_branch can be set to True: input the hidden state, the latest token ID, and the past_kv to get the new past_kv and the next token ID, then repeat this step.
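The per-step schedule described above can be sketched as a small pure function, assuming a single start token as the prompt. On cached steps only the most recently generated token is fed (as the question's code does with `result_idx[-1:]`), not the start token; earlier context comes from past_key_values. `decoder_step_inputs` is a hypothetical helper, not part of onnxruntime or Optimum.

```python
def decoder_step_inputs(token_ids: list[int], step: int) -> tuple[bool, list[int]]:
    """Return (use_cache_branch, input_ids) for one merged-decoder step.

    step 0 is the first forward pass: there is no cache yet, so the whole
    prefix is fed with use_cache_branch=False. Later steps feed only the
    newest token with use_cache_branch=True.
    """
    if step == 0:
        return False, list(token_ids)
    return True, token_ids[-1:]


print(decoder_step_inputs([0], 0))         # (False, [0])
print(decoder_step_inputs([0, 5, 7], 2))   # (True, [7])
```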

If you have questions about ONNX or Optimum, you can ask them in the following community Discussions, where you will usually get a quick response; if you’re in a hurry, that is the better option. Otherwise, the forum is fine.

I figured it out by reading through the transformers.js source.

When use_cache_branch is True, the decoder outputs present.X.encoder.key and present.X.encoder.value are empty tensors (shape [0, 16, 1, 64]), so they can’t be fed directly into the next decoding step. Instead, these encoder keys and values should keep using the results from the first decoder forward pass (where use_cache_branch is False), which have shape [1, 16, 196, 64]. The decoder keys and values (present.X.decoder.*), by contrast, are still taken from each step’s latest output.
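A minimal sketch of that fix, assuming the input/output names and the 8 layers used in the question’s code (other models will differ): build the next `past_key_values.*` inputs from the current outputs, but take the encoder (cross-attention) entries from the first pass once the cached branch is in use. `next_past_key_values` is a hypothetical helper; values are passed through untouched, so it works with the numpy arrays returned by `InferenceSession.run`.

```python
from typing import Any, Optional


def next_past_key_values(
    present: dict[str, Any],
    previous_past: Optional[dict[str, Any]],
    num_layers: int = 8,
) -> dict[str, Any]:
    """Build the past_key_values.* inputs for the next merged-decoder call.

    present: decoder outputs keyed by name (e.g. "present.0.decoder.key").
    previous_past: the dict returned by the previous call to this function,
        or None if `present` came from the first pass (use_cache_branch=False).
    """
    past: dict[str, Any] = {}
    for i in range(num_layers):
        for kind in ("key", "value"):
            # Self-attention (decoder) cache grows every step: take the new output.
            past[f"past_key_values.{i}.decoder.{kind}"] = present[f"present.{i}.decoder.{kind}"]
            enc_name = f"past_key_values.{i}.encoder.{kind}"
            if previous_past is None:
                # First pass: the cross-attention (encoder) cache is computed here.
                past[enc_name] = present[f"present.{i}.encoder.{kind}"]
            else:
                # Cached branch: encoder outputs are empty placeholders,
                # so reuse the cross-attention cache from the first pass.
                past[enc_name] = previous_past[enc_name]
    return past
```

In the question’s `inference` loop, this would mean keeping a `past` variable that starts as `None`, setting `past = next_past_key_values(outputs, past)` after every decoder call (where `outputs` is the name-to-array dict already built there), and passing `past` into the next call’s inputs instead of reading every `present.*` entry directly from the last result.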

A big shoutout to Xenova for their outstanding work!
