Need advice for implementing Greedy Search for ORTModelForSeq2SeqLM

Hi,

I recently converted my text translator [model] using Optimum Exporters, and I'm writing code to load the model with ONNX Runtime and run it without the PyTorch backend.

Since it seems that removing the torch dependency from Optimum is still in progress (optimum#526), I'm trying to implement the functions needed to get inference results from my model myself.

While stepping into the GenerationMixin.generate() function, I found out that greedy search (not beam search) is what gets called to process the logits and drive the decoder.

Anyway, I think I've more or less implemented greedy search, and most of the translation results are the same as the results from GenerationMixin.generate().
(For example, the input sentence “よろしくお願いします.” is correctly translated as “잘 부탁드립니다.”, which roughly means “thank you” or “nice to meet you”.)
However, some results are wrong. For example, the input “ご飯を食べましょう.” should produce the output “음, 이제 식사도 해볼까요”, but my greedy search just outputs “음, ” over and over: “음, 음, 음, 음, 음, …”

The following code is my greedy_search() function, and I'm lost as to what to change. Can anyone give me advice on making this function produce (nearly) the same results as transformers.generation.utils.GenerationMixin.greedy_search()?

import numpy as np

def greedy_search(_input_data, _encoder_session, _decoder_session, _trg_tokenizer, max_length=50):
    # `_input_data` is expected to already hold the encoder outputs
    # (e.g. encoder_hidden_states); start the decoder input with the BOS token
    _input_data['input_ids'] = np.array([[_trg_tokenizer.bos_token_id]]).astype(np.int64)

    # Initialize the list to store the generated tokens
    generated_tokens = []

    # Greedy search loop
    for _ in range(max_length):
        # Run the decoder model
        _decoder_output = _decoder_session.run(None, _input_data)

        # Update the past_key_values inputs with the current output
        # (outputs 1..24 are the key/value tensors of the 12 layers, in order)
        if _decoder_output[1] is not None:
            for i in range(12):
                _input_data[f'past_key_values.{i}.key'] = _decoder_output[1 + 2 * i]
                _input_data[f'past_key_values.{i}.value'] = _decoder_output[2 + 2 * i]
            _input_data['use_cache_branch'] = [True]

        # Extract the logits and apply softmax
        _logits = _decoder_output[0]
        _probabilities = np.exp(_logits) / np.sum(np.exp(_logits), axis=-1, keepdims=True)

        # Get the token with the highest probability
        next_token_id = np.argmax(_probabilities[:, -1, :], axis=-1).flatten()[0]

        # Append the token to the list
        generated_tokens.append(next_token_id)

        # Prepare the input for the next iteration
        _input_data['input_ids'] = np.array([[next_token_id]]).astype(np.int64)

        # Check if EOS token is generated
        if next_token_id == _trg_tokenizer.eos_token_id:
            break

    # Decode the generated tokens into text
    _generated_text = _trg_tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return _generated_text
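
For context, this is roughly how I call the function (a hypothetical sketch; the paths, tokenizer setup and input names here are assumptions, so check _encoder_session.get_inputs() / _decoder_session.get_inputs() for the exact names your export expects):

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

# Hypothetical paths, for illustration only
src_tokenizer = AutoTokenizer.from_pretrained('my_model/source_tokenizer')
trg_tokenizer = AutoTokenizer.from_pretrained('my_model/target_tokenizer')
encoder_session = ort.InferenceSession('my_model/encoder_model.onnx')
decoder_session = ort.InferenceSession('my_model/decoder_model_merged.onnx')

# Run the encoder once, then hand its outputs to the decoder loop
encoded = src_tokenizer('ご飯を食べましょう。', return_tensors='np')
encoder_hidden_states = encoder_session.run(None, {
    'input_ids': encoded['input_ids'].astype(np.int64),
    'attention_mask': encoded['attention_mask'].astype(np.int64),
})[0]

input_data = {
    'encoder_hidden_states': encoder_hidden_states,
    'encoder_attention_mask': encoded['attention_mask'].astype(np.int64),
}
print(greedy_search(input_data, encoder_session, decoder_session, trg_tokenizer))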

Maybe my way of getting probabilities and picking tokens is too simplistic?
But I couldn't find the corresponding code in the GenerationMixin.greedy_search() function…

        # Extract the logits and apply softmax
        _logits = _decoder_output[0]
        _probabilities = np.exp(_logits) / np.sum(np.exp(_logits), axis=-1, keepdims=True)

        # Get the token with the highest probability
        next_token_id = np.argmax(_probabilities[:, -1, :], axis=-1).flatten()[0]
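
(As a side note, the softmax step isn't actually needed for greedy search: softmax is monotonic, so taking the argmax over the raw logits selects the same token, and skipping it avoids possible overflow in np.exp for large logits, which would turn the probabilities into NaNs:)

        # Equivalent for greedy search: argmax over the raw logits,
        # since softmax preserves the ordering of the scores
        next_token_id = np.argmax(_logits[:, -1, :], axis=-1).flatten()[0]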

You can check the full code of my ONNX implementation, exercise_onnx_greedy.py, [here]; it sometimes produces wrong results.

(If your Python environment has the transformers, optimum, fugashi, unidic-lite and pretty_downloader packages installed, you can run my code right away.)

There is also the same example code written with Optimum (infer_onnx.py) and with PyTorch (infer_torch.py), both of which produce correct results.
(I tried to attach the direct links, but the system says I can add only two links since I'm a new user.)

Sorry if I posted this topic in the wrong category.

OK, I've found out what was wrong: I was handling the past_key_values incorrectly.

For the first loop iteration (the first Session.run) of the greedy search, all the past_key_values inputs should be zeros and use_cache_branch should be [False]. From the second iteration on, the past_key_values need to be updated while taking the cross-attention caches into account, as sketched below.
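
Here is a minimal sketch of the fixed cache handling, using the same variable names as the greedy_search() above. The tensor names assume the usual Optimum seq2seq export layout (past_key_values.{i}.decoder.key/.value for self-attention, past_key_values.{i}.encoder.key/.value for cross-attention, matching present.* outputs, and a logits output); inspect _decoder_session.get_inputs() / get_outputs() to confirm the names and shapes of your own export:

num_layers, num_heads, head_dim = 12, 12, 64  # model-specific; read these from your model config

output_names = [o.name for o in _decoder_session.get_outputs()]

# 1st run: feed all-zero dummy caches and take the non-cached branch
# (the dummy values are ignored on this branch; they just need valid shapes)
for i in range(num_layers):
    for branch in ('decoder', 'encoder'):
        for kind in ('key', 'value'):
            _input_data[f'past_key_values.{i}.{branch}.{kind}'] = np.zeros(
                (1, num_heads, 1, head_dim), dtype=np.float32)
_input_data['use_cache_branch'] = np.array([False])

for step in range(max_length):
    outputs = dict(zip(output_names, _decoder_session.run(None, _input_data)))

    for i in range(num_layers):
        # Self-attention caches grow by one position each step: always update them
        _input_data[f'past_key_values.{i}.decoder.key'] = outputs[f'present.{i}.decoder.key']
        _input_data[f'past_key_values.{i}.decoder.value'] = outputs[f'present.{i}.decoder.value']
        # Cross-attention caches depend only on the encoder output, so they are
        # only valid in the first (non-cached) run; keep them unchanged afterwards
        if step == 0:
            _input_data[f'past_key_values.{i}.encoder.key'] = outputs[f'present.{i}.encoder.key']
            _input_data[f'past_key_values.{i}.encoder.value'] = outputs[f'present.{i}.encoder.value']
    _input_data['use_cache_branch'] = np.array([True])

    # ... pick next_token_id from outputs['logits'] and check for EOS as before ...

The key point is that the cross-attention key/value tensors are computed once from the encoder output and never change, so from the second run on (where the model takes the cache branch) they must not be overwritten with the dummy tensors the cached branch returns.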

After fixing this, I got proper translation results successfully!

You can check my complete greedy search & inference code in my [repository].
