I am trying to convert a TFT5ForConditionalGeneration model with a custom config into a TFLite model. As far as I can see, implementing greedy decoding on my own seems to be the fastest route, but if you know a more straightforward approach, please tell me.
My plan is to run the encoder only once, when I pass in the entire sentence, and then reuse that encoder output as the input to the decoder at every step of the greedy search.
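To make the intended loop concrete, here is a minimal, framework-agnostic sketch of what I mean by "encode once, then decode greedily". It is only an illustration: `toy_encode`, `toy_decode_step`, `greedy_decode` and `VOCAB_SIZE` are hypothetical stand-ins that I made up, not part of transformers or TensorFlow. The start token 0 and end-of-sequence token 1 are assumptions chosen to mirror the usual T5 convention.

```python
import numpy as np

VOCAB_SIZE = 8  # toy vocabulary size (hypothetical)
encoder_calls = {"count": 0}  # tracks how often the encoder runs


def toy_encode(input_ids):
    """Hypothetical encoder: collapses each input sequence to one number."""
    encoder_calls["count"] += 1
    return np.sum(input_ids, axis=-1, keepdims=True) % VOCAB_SIZE  # shape (batch, 1)


def toy_decode_step(decoder_ids, encoder_state):
    """Hypothetical decoder: returns logits of shape (batch, seq_len, VOCAB_SIZE)."""
    next_ids = (decoder_ids + encoder_state) % VOCAB_SIZE  # shape (batch, seq_len)
    return np.eye(VOCAB_SIZE)[next_ids]  # one-hot rows stand in for real logits


def greedy_decode(encode, decode_step, input_ids, start_id=0, eos_id=1, max_len=16):
    """Run the encoder once, then extend the decoder input one token at a time."""
    encoder_state = encode(input_ids)  # computed a single time and reused below
    decoder_ids = [start_id]
    for _ in range(max_len):
        logits = decode_step(np.asarray([decoder_ids]), encoder_state)
        next_id = int(np.argmax(logits[0, -1]))  # best token at the last position
        decoder_ids.append(next_id)
        if next_id == eos_id:  # stop once the end-of-sequence token is produced
            break
    return decoder_ids


result = greedy_decode(toy_encode, toy_decode_step, np.asarray([[1, 2]]))
print(result)  # [0, 3, 6, 1]
```

In my real setup, `encode` would be the T5 encoder and `decode_step` would wrap the T5 decoder plus the language-model head. My actual attempt with the T5 model follows.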
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, T5Config, TFT5ForConditionalGeneration
distill_config = T5Config(d_model=256, d_kv = 32, d_ff=512, num_heads=4, decoder_start_token_id=0)
tf_model = TFT5ForConditionalGeneration(config=distill_config)
tokenizer = AutoTokenizer.from_pretrained("t5-small", padding='max_length', truncation=True)
inputs = tokenizer("this is a random input", return_tensors="tf")['input_ids']
encoder_outputs = tf_model.encoder(inputs)
decoder_input_ids = tf.convert_to_tensor(np.asarray([[0]]).astype(np.int32))
output = tf_model.decoder(decoder_input_ids = decoder_input_ids, encoder_outputs=encoder_outputs.last_hidden_state)
This is the error I get when I run this code:
ValueError Traceback (most recent call last)
<ipython-input-5-face8f4fd36f> in <module>
10 encoder_outputs = tf_model.encoder(inputs)
11 decoder_input_ids = tf.convert_to_tensor(np.asarray([[0]]).astype(np.int32))
---> 12 output = tf_model.decoder(decoder_input_ids = decoder_input_ids, encoder_outputs=encoder_outputs.last_hidden_state)
/usr/local/lib/python3.9/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
68 # To get the full stack trace, call:
69 # `tf.debugging.disable_traceback_filtering()`
---> 70 raise e.with_traceback(filtered_tb) from None
71 finally:
72 del filtered_tb
/usr/local/lib/python3.9/dist-packages/keras/utils/layer_utils.py in split_out_first_arg(self, args, kwargs)
807 inputs = kwargs.pop(self._arg_names[0])
808 else:
--> 809 raise ValueError(
810 "The first argument to `Layer.call` must always be passed."
811 )
ValueError: The first argument to `Layer.call` must always be passed.
So what is the right way to call the decoder with the precomputed encoder output?