Training loss is not decreasing using TFBertModel

I used TFBertModel and AutoModel from the transformers library to train a two-class classifier, but the training loss does not decrease.

import tensorflow as tf
from transformers import TFBertModel

# SEQ_LEN is the fixed token length of each input sequence, defined elsewhere.
bert = TFBertModel.from_pretrained('bert-base-uncased')
input_ids = tf.keras.layers.Input(shape=(SEQ_LEN,), name='input_ids', dtype='int32')
mask = tf.keras.layers.Input(shape=(SEQ_LEN,), name='attention_mask', dtype='int32')
# Index 1 of the BERT output is the pooled output.
embeddings = bert(input_ids, attention_mask=mask)[1]
X = tf.keras.layers.Dropout(0.1)(embeddings)
X = tf.keras.layers.Dense(128, activation='relu')(X)
y = tf.keras.layers.Dense(1, activation='sigmoid', name='outputs')(X)
bert_model = tf.keras.Model(inputs=[input_ids, mask], outputs=y)
# optimizer='adam' uses Adam with its default learning rate.
bert_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

But when I use TFBertForSequenceClassification, the model converges quickly and the training loss reaches zero.

import tensorflow as tf
from transformers import TFBertForSequenceClassification

bert_model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
# The model returns raw logits for the two classes, hence from_logits=True.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5, epsilon=1e-08)
bert_model.compile(loss=loss, optimizer=optimizer, metrics=[metric])

I want to use the sequence output of BERT, so I need to load the model with TFBertModel or something similar that returns BERT's outputs.

@Rocketknight1

Your code in the first block looks like it would work. I suspect the problem is one of two things:

  1. The dropout may prevent the model from fully converging.
  2. The default learning rate for Adam is 1e-3, which is much too high for training Transformer models. Try learning rates in the range 1e-5 to 1e-4.

If training loss is still not decreasing even with a lower learning rate and no dropout then let me know and I’ll investigate further.
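
As a rough sketch of the second point, the optimizer can be built explicitly instead of passing the string 'adam', which falls back to Adam's default rate of 1e-3. The helper name compile_for_finetuning is made up for illustration, and 2e-5 is simply the value from the TFBertForSequenceClassification snippet in the question, not a tuned recommendation.

```python
import tensorflow as tf


def compile_for_finetuning(model, learning_rate=2e-5):
    """Compile a binary classifier with an explicit, small Adam learning rate.

    Hypothetical helper for illustration; the default of 2e-5 is an assumption
    taken from the working snippet in the question.
    """
    # optimizer='adam' (a string) would silently use the default rate of 1e-3.
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer,
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model
```

Calling compile_for_finetuning(bert_model) would take the place of the bert_model.compile(...) line in the first block of the question.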

Thank you for your reply. I removed the dropout layer and lowered the learning rate, and it worked!