Multi-task learning for masked language modeling and token classification

Hi,

I have a question regarding multi-task learning: I would like to do domain adaptation using the masked language modeling task, but I also want to use token classification as an auxiliary task. For token classification, I want to predict for each token whether it is part of any entity or not (binary prediction). I implemented a custom BertForMultiTask class and would like to ask if anyone else has tried a similar approach and has any comments on it.

I created this custom class for using two heads by following the example of BertOnlyMLMHead:

class BertMultiTaskHeads(nn.Module):
    """Two token-level heads applied to the same encoder output.

    One head is a ``BertLMPredictionHead`` producing masked-LM scores; the
    other is dropout followed by a linear layer with two outputs, used for
    the binary "is this token part of an entity" prediction.
    """

    def __init__(self, config):
        """Build both heads from ``config``.

        Reads ``config.hidden_dropout_prob`` and ``config.hidden_size``;
        ``config`` is also passed unchanged to ``BertLMPredictionHead``.
        """
        super().__init__()
        # Masked-LM head, same component the stock MLM-only head wraps.
        self.predictions = BertLMPredictionHead(config)
        # Auxiliary binary head: dropout, then a 2-way linear classifier.
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
        self.binary_predictions = nn.Linear(
            in_features=config.hidden_size,
            out_features=2,
        )

    def forward(self, sequence_output):
        """Return ``(mlm_scores, entity_scores)`` for ``sequence_output``.

        ``sequence_output`` is presumably the encoder's last hidden state
        with ``config.hidden_size`` as its final dimension -- confirm
        against the caller.
        """
        mlm_scores = self.predictions(sequence_output)
        # Dropout is applied only on the auxiliary branch; the MLM head
        # receives the untouched encoder output.
        regularized = self.dropout(sequence_output)
        entity_scores = self.binary_predictions(regularized)
        return (mlm_scores, entity_scores)

I then created the BertForMultiTask class by following BertForMaskedLM (I had previously created a new BertMultiTaskOutput class that also returns the entity prediction loss):

class BertForMultiTask(BertPreTrainedModel):
    """BERT encoder with a masked-LM head and a per-token binary entity head.

    The encoder output is fed to ``BertMultiTaskHeads``, which returns
    masked-LM scores and 2-way entity scores for every token. When labels
    are supplied, the cross-entropy losses of the tasks that have labels
    are summed into a single training loss.
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        """Create the pooler-free encoder and the two task heads."""
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertMultiTaskHeads(config)
        self.init_weights()

    def get_output_embeddings(self):
        """Return the MLM decoder layer.

        Exposing it lets the base class tie the decoder to the input word
        embeddings and resize it together with them, as ``BertForMaskedLM``
        does. Relies on ``BertLMPredictionHead`` exposing ``decoder``.
        """
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the MLM decoder layer with ``new_embeddings``."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            tag_labels=None,  # For entity binary prediction, obtained as an additional feature
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            **kwargs
    ):
        """Run the encoder and both heads, optionally computing the loss.

        Args:
            labels: Masked-LM targets (vocabulary ids), one per token.
            tag_labels: Binary entity targets, one per token.
            return_dict: Falls back to ``self.config.use_return_dict`` when
                ``None``.
            **kwargs: Accepted and ignored.

            All remaining arguments are forwarded unchanged to ``self.bert``.

        Returns:
            A ``BertMultiTaskOutput`` when ``return_dict`` is true, otherwise
            a tuple ``(prediction_scores, label_scores, *extra)`` prefixed
            with the loss when one was computed.

        Note:
            ``nn.CrossEntropyLoss`` is used with its defaults, so positions
            that must not contribute to a loss (unmasked tokens, padding)
            are expected to carry the default ignore index ``-100`` in the
            corresponding label tensor.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores, label_scores = self.cls(sequence_output)

        # Each task contributes only when its labels are present, so passing
        # a single label tensor yields that task's loss instead of None.
        total_loss = None
        if labels is not None or tag_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            if labels is not None:
                total_loss = loss_fct(
                    prediction_scores.view(-1, self.config.vocab_size),
                    labels.view(-1),
                )
            if tag_labels is not None:
                # additional entity prediction loss
                entity_prediction_loss = loss_fct(
                    label_scores.view(-1, label_scores.size(-1)),
                    tag_labels.view(-1),
                )
                if total_loss is None:
                    total_loss = entity_prediction_loss
                else:
                    total_loss = total_loss + entity_prediction_loss

        if not return_dict:
            # Same slicing as BertForMaskedLM; presumably index 1 is the
            # pooled-output slot of the encoder tuple -- confirm.
            output = (prediction_scores, label_scores) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return BertMultiTaskOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            entity_prediction_logits=label_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

Thank you very much in advance!