TFBertForSequenceClassification for multi-label classification

I am trying to fine-tune a BERT model for multi-label classification. The entire code is available in this Colab notebook.

Here is what my data looks like:

({'input_ids': <tf.Tensor: shape=(128,), dtype=int32, numpy=
array([    2,  8318,  1379,  7892,  2791, 20630,     1,     4,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0], dtype=int32)>, 
'attention_mask': <tf.Tensor: shape=(128,), dtype=int32, numpy=
array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)>}, 
<tf.Tensor: shape=(7,), dtype=int64, numpy=array([1, 0, 0, 0, 0, 0, 0])>)

The first element of each example is a dict holding the input_ids and the attention_mask, and the second element holds the labels - here I have 7 labels.
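For reference, the dataset elements were created along these lines (a simplified sketch, not the exact notebook code; `texts` and `labels` are placeholder names for the actual lists, and the batch size is only illustrative):

    import tensorflow as tf
    from transformers import BertTokenizer

    MODEL_NAME_OR_PATH = 'HooshvareLab/bert-fa-base-uncased'
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME_OR_PATH)

    # texts: list of strings, labels: list of 7-dim multi-hot vectors
    encodings = tokenizer(texts,
                          truncation=True,
                          padding='max_length',
                          max_length=128,
                          return_token_type_ids=False)

    train_dataset = tf.data.Dataset.from_tensor_slices((
        dict(encodings),                      # {'input_ids': ..., 'attention_mask': ...}
        tf.constant(labels, dtype=tf.int64)   # shape (num_examples, 7)
    )).batch(16)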

First effort:

    MODEL_NAME_OR_PATH = 'HooshvareLab/bert-fa-base-uncased'
    NUM_LABELS = 7
    
    import tensorflow as tf
    from transformers import TFBertForSequenceClassification, BertConfig
    model = TFBertForSequenceClassification.from_pretrained(
        MODEL_NAME_OR_PATH, 
        config=BertConfig.from_pretrained(MODEL_NAME_OR_PATH, num_labels=NUM_LABELS, problem_type="multi_label_classification")
        )
    
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer='adam', loss=loss, metrics=['accuracy'])
    history = model.fit(train_dataset, epochs=1, steps_per_epoch=115, validation_data=valid_dataset, validation_steps=7)

which fails with the following error:

InvalidArgumentError                      Traceback (most recent call last)

<ipython-input-48-4408a1f17fbe> in <module>()
     10 loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
     11 model.compile(optimizer='adam', loss=loss, metrics=['accuracy'])
---> 12 history = model.fit(train_dataset, epochs=1, steps_per_epoch=115, validation_data=valid_dataset, validation_steps=7)
     13 
     14 

1 frames

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     53     ctx.ensure_initialized()
     54     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 55                                         inputs, attrs, num_outputs)
     56   except core._NotOkStatusException as e:
     57     if name is not None:

InvalidArgumentError: Graph execution error:

Detected at node 'Equal' defined at (most recent call last):
    File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
      "__main__", mod_spec)
    File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
      exec(code, run_globals)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
      app.launch_new_instance()
    File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 846, in launch_instance
      app.start()
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
      self.io_loop.start()
    File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
      self._run_once()
    File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
      handle._run()
    File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
      self._context.run(self._callback, *self._args)
    File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 122, in _handle_events
      handler_func(fileobj, events)
    File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 577, in _handle_events
      self._handle_recv()
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 606, in _handle_recv
      self._run_callback(callback, msg)
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 556, in _run_callback
      callback(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
      return self.dispatch_shell(stream, msg)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
      handler(stream, idents, msg)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
      user_expressions, allow_stdin)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
      return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
      interactivity=interactivity, compiler=compiler, result=result)
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2822, in run_ast_nodes
      if self.run_code(code, result):
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "<ipython-input-48-4408a1f17fbe>", line 12, in <module>
      history = model.fit(train_dataset, epochs=1, steps_per_epoch=115, validation_data=valid_dataset, validation_steps=7)
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1384, in fit
      tmp_logs = self.train_function(iterator)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1021, in train_function
      return step_function(self, iterator)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1010, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/usr/local/lib/python3.7/dist-packages/transformers/modeling_tf_utils.py", line 1156, in train_step
      self.compiled_metrics.update_state(y, y_pred, sample_weight)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py", line 459, in update_state
      metric_obj.update_state(y_t, y_p, sample_weight=mask)
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/metrics_utils.py", line 70, in decorated
      update_op = update_state_fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/keras/metrics.py", line 178, in update_state_fn
      return ag_update_state(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/keras/metrics.py", line 729, in update_state
      matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "/usr/local/lib/python3.7/dist-packages/keras/metrics.py", line 4086, in sparse_categorical_accuracy
      return tf.cast(tf.equal(y_true, y_pred), backend.floatx())
Node: 'Equal'
required broadcastable shapes
	 [[{{node Equal}}]] [Op:__inference_train_function_187978]
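To double-check what the metric actually sees, the element shapes can be inspected like this (using the batched `train_dataset` from the notebook; the first dimension is whatever batch size was used):

    # take one batch and print the shapes fed into loss/metric computation
    for features, train_labels in train_dataset.take(1):
        print(features['input_ids'].shape)   # (batch_size, 128)
        print(train_labels.shape)            # (batch_size, 7) - multi-hot label vectors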

Second effort, inspired by the code below:

    import tensorflow as tf
    from transformers import TFBertPreTrainedModel
    from transformers import TFBertMainLayer


    class TFBertForMultilabelClassification(TFBertPreTrainedModel):

        def __init__(self, config, *inputs, **kwargs):
            super(TFBertForMultilabelClassification, self).__init__(config, *inputs, **kwargs)
            self.num_labels = config.num_labels

            self.bert = TFBertMainLayer(config, name='bert')
            self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
            self.classifier = tf.keras.layers.Dense(config.num_labels,
                                                    kernel_initializer='random_normal', #get_initializer(config.initializer_range),
                                                    name='classifier',
                                                    activation='sigmoid')

        def call(self, inputs, **kwargs):
            outputs = self.bert(inputs, **kwargs)

            pooled_output = outputs[1]

            pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
            logits = self.classifier(pooled_output)

            outputs = (logits,) + outputs[2:]  # add hidden states and attentions if they are here

            return outputs  # logits, (hidden_states), (attentions)


    MODEL_NAME_OR_PATH = 'HooshvareLab/bert-fa-base-uncased'
    NUM_LABELS = len(y_train[0])

    model = TFBertForMultilabelClassification.from_pretrained(MODEL_NAME_OR_PATH, num_labels=NUM_LABELS)

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, epsilon=1e-08, clipnorm=1)
    # the labels are multi-hot vectors, so use binary cross-entropy
    loss = tf.keras.losses.BinaryCrossentropy()
    metric = tf.keras.metrics.CategoricalAccuracy()
    model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
    history = model.fit(train_dataset, epochs=1, validation_data=valid_dataset)

which returns the following error:

InvalidArgumentError                      Traceback (most recent call last)

<ipython-input-49-8aa1173bef76> in <module>()
      4 metric = tf.keras.metrics.CategoricalAccuracy()
      5 model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
----> 6 history = model.fit(train_dataset, epochs=1, validation_data=valid_dataset)

1 frames

    /usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
         53     ctx.ensure_initialized()
         54     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
    ---> 55                                         inputs, attrs, num_outputs)
         56   except core._NotOkStatusException as e:
         57     if name is not None:
    
    InvalidArgumentError: Graph execution error:
    
    Detected at node 'Equal' defined at (most recent call last):
        File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
          "__main__", mod_spec)
        File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
          exec(code, run_globals)
        File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
          app.launch_new_instance()
        File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 846, in launch_instance
          app.start()
        File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
          self.io_loop.start()
        File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
          self.asyncio_loop.run_forever()
        File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
          self._run_once()
        File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
          handle._run()
        File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
          self._context.run(self._callback, *self._args)
        File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 122, in _handle_events
          handler_func(fileobj, events)
        File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
          return fn(*args, **kwargs)
        File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 577, in _handle_events
          self._handle_recv()
        File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 606, in _handle_recv
          self._run_callback(callback, msg)
        File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 556, in _run_callback
          callback(*args, **kwargs)
        File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
          return fn(*args, **kwargs)
        File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
          return self.dispatch_shell(stream, msg)
        File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
          handler(stream, idents, msg)
        File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
          user_expressions, allow_stdin)
        File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
          res = shell.run_cell(code, store_history=store_history, silent=silent)
        File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
          return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
        File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
          interactivity=interactivity, compiler=compiler, result=result)
        File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2822, in run_ast_nodes
          if self.run_code(code, result):
        File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
          exec(code_obj, self.user_global_ns, self.user_ns)
        File "<ipython-input-49-8aa1173bef76>", line 6, in <module>
          history = model.fit(train_dataset, epochs=1, validation_data=valid_dataset)
        File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
          return fn(*args, **kwargs)
        File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1384, in fit
          tmp_logs = self.train_function(iterator)
        File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1021, in train_function
          return step_function(self, iterator)
        File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1010, in step_function
          outputs = model.distribute_strategy.run(run_step, args=(data,))
        File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1000, in run_step
          outputs = model.train_step(data)
        File "/usr/local/lib/python3.7/dist-packages/transformers/modeling_tf_utils.py", line 1156, in train_step
          self.compiled_metrics.update_state(y, y_pred, sample_weight)
        File "/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py", line 459, in update_state
          metric_obj.update_state(y_t, y_p, sample_weight=mask)
        File "/usr/local/lib/python3.7/dist-packages/keras/utils/metrics_utils.py", line 70, in decorated
          update_op = update_state_fn(*args, **kwargs)
        File "/usr/local/lib/python3.7/dist-packages/keras/metrics.py", line 178, in update_state_fn
          return ag_update_state(*args, **kwargs)
        File "/usr/local/lib/python3.7/dist-packages/keras/metrics.py", line 729, in update_state
          matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
        File "/usr/local/lib/python3.7/dist-packages/keras/metrics.py", line 4086, in sparse_categorical_accuracy
          return tf.cast(tf.equal(y_true, y_pred), backend.floatx())
    Node: 'Equal'
    required broadcastable shapes
    	 [[{{node Equal}}]] [Op:__inference_train_function_214932]

I suspect the issue comes from the major changes in both TF2 and the (TF-based) Hugging Face transformers.
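In case the library versions matter, they can be checked in the Colab environment like this (I have not pinned specific versions, so this only shows how to inspect them):

    # print the TensorFlow and transformers versions in use
    import tensorflow as tf
    import transformers

    print(tf.__version__)
    print(transformers.__version__)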

I have also asked this question on Stack Overflow, as I believe there is not much material covering the latest changes.