Augmenting a token classifier class to concatenate metadata to the hidden representation of the text, before the output classification layer

Hello all (apologies for cross-posting; this is on Stack Overflow as well),

I’m trying to implement a token classifier, but I want to use rich text metadata that I’ve got, in addition to the text itself.

The model I want to fit is the following:

label = dot_product([hidden_state, metadata], W)

where the hidden state is the output of TFDistilBertMainLayer (documented here)
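
To make the intended shapes concrete, here is a small NumPy sketch of that formula. It assumes the metadata is available per token with 3 features; the sizes and the random values are purely illustrative and the names are not from any library.

```python
import numpy as np

# Illustrative sizes only (assumptions, not taken from a real model config).
batch_size, seq_len, hidden_dim = 2, 4, 8
n_meta, num_labels = 3, 5  # 3 metadata features per token, 5 tags

rng = np.random.default_rng(0)
hidden_state = rng.normal(size=(batch_size, seq_len, hidden_dim))  # encoder output
metadata = rng.normal(size=(batch_size, seq_len, n_meta))          # per-token metadata
W = rng.normal(size=(hidden_dim + n_meta, num_labels))             # classifier weights

# Concatenate along the feature axis, then project to one score per tag.
features = np.concatenate([hidden_state, metadata], axis=-1)  # (2, 4, 11)
logits = features @ W                                         # (2, 4, 5)
labels = logits.argmax(axis=-1)                               # (2, 4)
```

If the metadata were per sequence rather than per token, it would first have to be repeated along the sequence axis so that the two arrays line up.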

I’m trying to tweak huggingface’s token classifier tutorial, which is here

Here is my start at modifying the TFDistilBertForTokenClassification class for my purposes:

import tensorflow as tf
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import (concatenate, Dense, Dropout)
from transformers.modeling_tf_outputs import (
    TFTokenClassifierOutput,  # the output class that call() actually returns
)
from transformers.modeling_tf_utils import (
    TFTokenClassificationLoss,
    get_initializer,
)
from transformers.tokenization_utils import BatchEncoding

from transformers import TFDistilBertPreTrainedModel, TFDistilBertMainLayer

unique_tags = list(range(5))

class my_cool_model(TFDistilBertPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
        inp_p = Input(shape=3)
        self.conc = concatenate([self.distilbert, inp_p])
        self.dropout = tf.keras.layers.Dropout(config.dropout)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )


    def call(
        self,
        inputs=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
            1]``.
        """
        return_dict = return_dict if return_dict is not None else self.distilbert.return_dict
        if isinstance(inputs, (tuple, list)):
            labels = inputs[7] if len(inputs) > 7 else labels
            if len(inputs) > 7:
                inputs = inputs[:7]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        outputs = self.distilbert(
            inputs,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output, training=training)
        logits = self.classifier(sequence_output)

        loss = None if labels is None else self.compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

model = my_cool_model.from_pretrained('distilbert-base-cased', num_labels=len(unique_tags))
model.summary()

It fails when it gets to the concatenate step, with the following error:

TypeError: 'NoneType' object is not subscriptable

I think that the problem is that the distilbert object doesn’t get treated like a regular tf/keras layer:

self.distilbert
<transformers.modeling_tf_distilbert.TFDistilBertMainLayer object at 0x7feabd9b8d50>

So, how would I go about building a token classifier that also has metadata here?
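
One possible direction, as a rough sketch rather than a tested solution: `concatenate` operates on tensors, while `self.distilbert` inside `__init__` is a layer object, so the concatenation probably belongs in `call`, applied to the encoder's output tensor. The example below uses a plain Keras `Embedding` layer as a hypothetical stand-in for the DistilBERT main layer so that it runs without downloading weights. The class name, the sizes, and the assumption that the metadata is per token with 3 features are all made up for illustration.

```python
import tensorflow as tf

class TokenClassifierWithMetadata(tf.keras.Model):
    """Sketch only: an Embedding layer stands in for the real encoder."""

    def __init__(self, vocab_size, hidden_dim, num_labels, dropout_rate=0.1):
        super().__init__()
        # Only layers are created here; no tensors exist yet.
        self.encoder = tf.keras.layers.Embedding(vocab_size, hidden_dim)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.classifier = tf.keras.layers.Dense(num_labels)

    def call(self, inputs, training=False):
        input_ids, metadata = inputs
        hidden = self.encoder(input_ids)                   # (batch, seq, hidden)
        hidden = self.dropout(hidden, training=training)
        # Concatenate actual tensors along the feature axis.
        combined = tf.concat([hidden, metadata], axis=-1)  # (batch, seq, hidden + 3)
        return self.classifier(combined)                   # (batch, seq, num_labels)

model = TokenClassifierWithMetadata(vocab_size=100, hidden_dim=8, num_labels=5)
input_ids = tf.constant([[1, 2, 3, 4], [5, 6, 7, 0]])
metadata = tf.random.uniform((2, 4, 3))  # assumed per-token metadata
logits = model((input_ids, metadata))
```

In the class above, the analogous change would be to drop the `Input` and `concatenate` lines from `__init__`, pass the metadata into `call` as an extra argument, and concatenate it with `sequence_output` before the classifier. I have not verified this against the actual transformers classes.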

Excuse me, did you solve it?
