Finetuning T5 on custom data

I have a small text dataset for translation on which I want to fine-tune t5-small. Here is the code I am trying to use for fine-tuning.

import numpy as np
import tensorflow as tf
from transformers import TFT5ForConditionalGeneration, T5Tokenizer

model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
tokenizer = T5Tokenizer.from_pretrained('t5-small')

def data_gen():
  for _ in range(256):
    x = np.random.randint(1,tokenizer.vocab_size, model.config.n_positions)
    attention = np.ones_like(x)
    yield ((x, attention), (x, attention))

output_type = ((tf.int32, tf.int32), (tf.int32, tf.int32))
ds = tf.data.Dataset.from_generator(data_gen, output_type).batch(2)

optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer, loss=loss)

model.fit(ds, epochs=3, steps_per_epoch=128)

Running the above code produces the following error:

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at t5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.
Epoch 1/3
ValueError                                Traceback (most recent call last)
<ipython-input-12-12c0ab7ab337> in <module>()
     19 model.compile(optimizer=optimizer, loss=loss)
---> 21 model.fit(ds, epochs=3, steps_per_epoch=128)

10 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ in wrapper(*args, **kwargs)
    971           except Exception as e:  # pylint:disable=broad-except
    972             if hasattr(e, "ag_error_metadata"):
--> 973               raise e.ag_error_metadata.to_exception(e)
    974             else:
    975               raise

ValueError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/ train_function  *
        return step_function(self, iterator)
    /usr/local/lib/python3.6/dist-packages/transformers/ call  *
        encoder_outputs = self.encoder(
    /usr/local/lib/python3.6/dist-packages/transformers/ call  *
        input_shape = shape_list(input_ids)
    /usr/local/lib/python3.6/dist-packages/transformers/ shape_list  *
        static = x.shape.as_list()
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ as_list  **
        raise ValueError("as_list() is not defined on an unknown TensorShape.")

    ValueError: as_list() is not defined on an unknown TensorShape