How to get T5 decoded logits using TFT5ForConditionalGeneration from encoded outputs?

I am trying to convert a TFT5ForConditionalGeneration with a custom config into a TFLite model, and as far as I can see, implementing greedy decoding myself is the fastest way to get there. If you know a more straightforward approach, please do tell me.

The plan is to run the encoder only once, when the whole input sentence is passed in, and then reuse that encoded vector as the cross-attention input to the decoder at every step of the greedy search.
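
This is a rough sketch of the loop I have in mind (just the idea, not the code that fails below; I am assuming here that the full model call accepts a precomputed encoder_outputs, and max_new_tokens is my own name):

import tensorflow as tf
from transformers import AutoTokenizer, TFT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-small")
tf_model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

enc = tokenizer("translate English to German: I love coffee.", return_tensors="tf")

# run the encoder exactly once and keep its output around
encoder_outputs = tf_model.encoder(enc["input_ids"], attention_mask=enc["attention_mask"])

# T5 starts decoding from the pad token (id 0)
decoder_input_ids = tf.zeros((1, 1), dtype=tf.int32)
max_new_tokens = 20

for _ in range(max_new_tokens):
    out = tf_model(
        input_ids=None,  # unused once encoder_outputs is given, but Keras wants the first argument
        attention_mask=enc["attention_mask"],
        decoder_input_ids=decoder_input_ids,
        encoder_outputs=encoder_outputs,  # reuse the encoded vector at every step
    )
    next_id = tf.argmax(out.logits[:, -1, :], axis=-1, output_type=tf.int32)[:, None]
    decoder_input_ids = tf.concat([decoder_input_ids, next_id], axis=-1)
    if int(next_id[0, 0]) == tokenizer.eos_token_id:
        break

print(tokenizer.decode(decoder_input_ids[0].numpy(), skip_special_tokens=True))

My current attempt at actually splitting the encoder and decoder apart looks like this: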

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, T5Config, TFT5ForConditionalGeneration

distill_config = T5Config(d_model=256, d_kv=32, d_ff=512, num_heads=4, decoder_start_token_id=0)
tf_model = TFT5ForConditionalGeneration(config=distill_config)
tokenizer = AutoTokenizer.from_pretrained("t5-small", padding='max_length', truncation=True)
inputs = tokenizer("this is a random input", return_tensors="tf")['input_ids']

encoder_outputs = tf_model.encoder(inputs)
decoder_input_ids = tf.convert_to_tensor(np.asarray([[0]]).astype(np.int32))
output = tf_model.decoder(decoder_input_ids = decoder_input_ids, encoder_outputs=encoder_outputs.last_hidden_state)


This is the error I get when I run this code:

ValueError                                Traceback (most recent call last)
<ipython-input-5-face8f4fd36f> in <module>
     10 encoder_outputs = tf_model.encoder(inputs)
     11 decoder_input_ids = tf.convert_to_tensor(np.asarray([[0]]).astype(np.int32))
---> 12 output = tf_model.decoder(decoder_input_ids = decoder_input_ids, encoder_outputs=encoder_outputs.last_hidden_state)

1 frames
/usr/local/lib/python3.9/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     68             # To get the full stack trace, call:
     69             # `tf.debugging.disable_traceback_filtering()`
---> 70             raise e.with_traceback(filtered_tb) from None
     71         finally:
     72             del filtered_tb

/usr/local/lib/python3.9/dist-packages/keras/utils/layer_utils.py in split_out_first_arg(self, args, kwargs)
    807             inputs = kwargs.pop(self._arg_names[0])
    808         else:
--> 809             raise ValueError(
    810                 "The first argument to `Layer.call` must always be passed."
    811             )

ValueError: The first argument to `Layer.call` must always be passed.
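
From trial and error, this particular error goes away if I pass the ids as the first positional argument and hand over the encoded vector as encoder_hidden_states (which is what I do in the last snippet below), e.g. continuing from the snippet above:

output = tf_model.decoder(decoder_input_ids, encoder_hidden_states=encoder_outputs.last_hidden_state)

But I am not sure this is the intended way to call the decoder.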

So what is the right way to do this?

When I call the full model directly, without splitting out the encoder, I get the expected output:

import tensorflow as tf
from transformers import AutoTokenizer, T5Config, TFT5ForConditionalGeneration, set_seed
set_seed(0)
tokenizer = AutoTokenizer.from_pretrained("t5-small", padding='max_length', truncation=True)
tf_model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
inputs = tokenizer("i got permission to begin a start up company by my own..</s>",return_tensors='tf')
attn = inputs['attention_mask']

decoder_input = tf.zeros((1,1), dtype=tf.int64)
output = tf_model(input_ids=inputs['input_ids'], attention_mask = attn, decoder_input_ids=decoder_input).logits

print(tokenizer.batch_decode(output.numpy().argmax(-1).tolist()), output.numpy().argmax(-1).tolist())

Output:

[''] [[3]]

But I get a different answer when I run the encoder and decoder separately, as below.

import tensorflow as tf
from transformers import AutoTokenizer, T5Config, TFT5ForConditionalGeneration, set_seed
set_seed(0)
tokenizer = AutoTokenizer.from_pretrained("t5-small", padding='max_length', truncation=True)
tf_model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
inputs = tokenizer("i got permission to begin a start up company by my own..</s>",return_tensors='tf')
attn = inputs['attention_mask']

encoder_outputs = tf_model.encoder(inputs['input_ids'], attention_mask = attn, return_dict = True)
output = tf_model.decoder(decoder_input, encoder_hidden_states=encoder_outputs.last_hidden_state).last_hidden_state

print(tokenizer.batch_decode(output.numpy().argmax(-1).tolist()), output.numpy().argmax(-1).tolist())

Output:

['une'] [[245]]
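
My guess is that tf_model.decoder(...) returns the raw decoder hidden states (dimension d_model), not vocabulary logits, so the argmax over last_hidden_state is not comparable to the logits from the first snippet. If that is right, I think I still need the final projection that the full model applies. Something like the sketch below, continuing from the snippet above, is what I have in mind (the get_input_embeddings() lookup and the d_model ** -0.5 rescaling are just my reading of modeling_tf_t5.py, which I have not verified):

# guess at the missing step: project decoder hidden states to vocabulary logits
# `output` is the last_hidden_state from the snippet above
d_model = tf_model.config.d_model
embedding_matrix = tf_model.get_input_embeddings().weights[0]  # assumed shape: (vocab_size, d_model)

# with tied word embeddings, T5 rescales the hidden states before the projection
logits = tf.matmul(output * (d_model ** -0.5), embedding_matrix, transpose_b=True)
print(tokenizer.batch_decode(tf.argmax(logits, axis=-1).numpy().tolist()))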

So how should I properly feed the encoder output to the decoder, and is that projection the right way to recover the logits?