How to fine tune TFMT5ForConditionalGeneration for text classification?

Hi, I have a problem in fine-tuning TFMT5ForConditionalGeneration for text classification with Tensorflow ‘2.6.0’ and Transformers ‘4.11.2’.

My task is to classify text sentences to one of the severity levels (‘1’, ‘2’, ‘3’, ‘4’, ‘5’).


df = pd.read_csv(FILE, header=0, dtype=str, sep='\t', encoding='utf-8')
X_train, X_eval, y_train, y_eval = train_test_split(list(df.RPT_CNTS), list(df.RECV_EMG_CD), test_size=TEST_SPLIT)

tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")
train_inputs = tokenizer(X_train, padding='max_length', truncation=True, max_length=100, return_tensors="tf")
train_labels = tokenizer(y_train, padding='max_length', truncation=True, max_length=2)
labels = train_labels.input_ids
labels = [
           [(label if label != 1 else -100) for label in labels_example] for labels_example in labels
]
train_inputs['labels'] = tf.convert_to_tensor(labels, dtype=tf.int32)

train_dataset = tf.data.Dataset.from_tensor_slices((
    dict(train_inputs),
    tf.convert_to_tensor(labels, dtype=tf.int32)
)).shuffle(10000).batch(128)

class TFT5Classifier(tf.keras.Model):

    def __init__(self, model_name):
        super(TFT5Classifier, self).__init__()
        self.t5 = TFMT5ForConditionalGeneration.from_pretrained(model_name)
        
    def call(self, inputs, attention_mask=None, labels=None, training=False):
        outputs = self.t5(inputs, attention_mask=attention_mask, labels=labels)
        return outputs.logits

model = TFT5Classifier('google/mt5-small')

optimizer = tf.keras.optimizers.Adam(2e-5)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
history = model.fit(train_dataset, epochs=1, batch_size=128)

However, it does not work as follows:


231/231 [==============================] - 131s 568ms/step - loss: nan - accuracy: 0.0000e+00
{'loss': [nan], 'accuracy': [0.0]}

Would you please help me out? Thank you!!