Faulty encoder-decoder model where the encoder is a DistilBERT and the decoder is a Keras BiLSTM

Hi all,

I have some questions about a transformer-based model in Keras that follows the encoder-decoder architecture, where the encoder is a Huggingface DistilBERT and the decoder is a BiLSTM with a mask; this mask (associated with the input “mentions”) is a boolean mask indicating whether each token belongs to some disease code or not. The problem addressed is token classification for ICD coding, that is, given a text, predict the codes representing the diseases present in it. Let’s take the example “The patient has flu”. Assuming there is one token per word, the output predicted with 100% accuracy would be [“0”, “0”, “0”, “PREDICTED_FLU_CODE”].
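
To make the expected input/output alignment concrete, here is a hypothetical toy illustration (in reality “flu” may be split into several subword tokens and there are special [CLS]/[SEP] tokens, which I ignore here):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-multilingual-cased")
encoding = tokenizer("The patient has flu", padding="max_length",
                     max_length=256, truncation=True)
# encoding["input_ids"] and encoding["attention_mask"] feed the encoder; the
# per-token targets would look like ["0", "0", "0", "PREDICTED_FLU_CODE", ...padding...]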

For the time being, it is not necessary to recognise codes that were not seen in training (i.e. no zero-shot learning is needed), so there are 750 possible codes, belonging to different ailments, in the training dataset. The skeleton of the model is this:

import tensorflow as tf
from transformers import TFDistilBertForMaskedLM

MAX_DOC_LENGTH = 256
categories_qty = 750
TRANSFORMER_MODEL = "distilbert-base-multilingual-cased"
distilbert_model = TFDistilBertForMaskedLM.from_pretrained(TRANSFORMER_MODEL,
                                                           output_hidden_states=True,
                                                           output_attentions=True)
input_ids = tf.keras.layers.Input(shape=(MAX_DOC_LENGTH,), dtype=tf.int32, name="input_ids")
attention_mask = tf.keras.layers.Input(shape=(MAX_DOC_LENGTH,), dtype=tf.int32, name="attention_mask")
labels = tf.keras.layers.Input(shape=(MAX_DOC_LENGTH,), dtype=tf.int32, name="labels")
# Encoder: the underlying DistilBERT (without the MLM head)
distilbert_output = distilbert_model.distilbert({"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels})
# Boolean mask marking which tokens belong to a disease mention
mentions = tf.keras.layers.Input(shape=(MAX_DOC_LENGTH,), dtype=tf.bool, name="mentions")

# Adapt the encoder hidden size (768) to the LSTM size (MAX_DOC_LENGTH)
intermediate_layer = tf.keras.layers.Dense(MAX_DOC_LENGTH, activation=None, name="intermezzo_layer")(distilbert_output.last_hidden_state)
# Decoder: BiLSTM constrained by the mentions mask
bilstm_layer = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(units=MAX_DOC_LENGTH, return_sequences=True),
    input_shape=(MAX_DOC_LENGTH, categories_qty),
    initial_state=???,  # placeholder, see question 2 below
    name="bidirectional")(intermediate_layer, training=True, mask=mentions)
# Per-token classification over the 750 codes
softmax = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(categories_qty, activation='softmax', name="softmax_layer"))(bilstm_layer)
model = tf.keras.models.Model(inputs={"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "mentions": mentions}, outputs=softmax)
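
In case it helps, a minimal sketch of how a model like this can be compiled and trained, assuming categorical cross-entropy over the one-hot targets (the optimizer, batch size, epochs and the train_dataset/dev_dataset names are just illustrative):

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
              loss="categorical_crossentropy",
              metrics=["accuracy"])
# train_dataset / dev_dataset are built the same way as test_dataset below
model.fit(train_dataset.batch(16), validation_data=dev_dataset.batch(16), epochs=3)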

However, the model does not infer what it is intended to. With the current configuration, it dismisses the “0” tag that the vast majority of tokens (those not corresponding to any illness) should receive and instead assumes that every token belongs to a disease mention; it does detect that the actual mention gets a code other than “0”, but assigns it the wrong one. Abbreviated example of this behaviour:

Test example #1:
decoded_y_true[0] value (length=256): ['0', '0', '0', '0', '0', '0', 'I46.9', '0', '0', '0', '0', '0', '0', '0']
decoded_y_pred[0] value (length=256): ['C80', 'C80', 'C80', 'C80', 'C80', 'C80', 'J44.1', 'C80', 'C80', 'C80', 'C80', 'C80', 'C80', 'C80', etc.]

I have some short questions/doubts concerning this model:

1 - Since DistilBERT plays the role of the encoder here, I suppose the right implementation is TFDistilBertForMaskedLM (using MLM to fine-tune the general-purpose model and adapt it to this biomedical context) as opposed to TFDistilBertForTokenClassification. Am I right?
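
(Purely to make the two alternatives concrete; both class names are the actual Huggingface ones, the snippet is only illustrative:)

from transformers import TFDistilBertForMaskedLM, TFDistilBertForTokenClassification

mlm_variant = TFDistilBertForMaskedLM.from_pretrained(TRANSFORMER_MODEL)
token_clf_variant = TFDistilBertForTokenClassification.from_pretrained(TRANSFORMER_MODEL,
                                                                       num_labels=categories_qty)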

2 - How do I initialise the initial_state parameter of the LSTM? This is pretty straightforward when the encoder is another LSTM, but I have not found the proper way to do it when the encoder is a Huggingface transformer. Also, should I build the initial_state directly from DistilBERT’s hidden state, or through an intermediate layer that adapts the hidden size (768 by default) to MAX_DOC_LENGTH?
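
For context, the closest thing I have come up with is projecting the encoder’s [CLS] hidden state into the four state tensors the Bidirectional LSTM expects ([forward_h, forward_c, backward_h, backward_c]); the following is just a sketch of that idea, not something I know to be correct:

# Sketch: derive the BiLSTM initial state from DistilBERT's [CLS] hidden state
cls_hidden = distilbert_output.last_hidden_state[:, 0, :]                     # (batch, 768)
state_h = tf.keras.layers.Dense(MAX_DOC_LENGTH, activation="tanh")(cls_hidden)
state_c = tf.keras.layers.Dense(MAX_DOC_LENGTH, activation="tanh")(cls_hidden)
bilstm_layer = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(units=MAX_DOC_LENGTH, return_sequences=True),
    name="bidirectional")(intermediate_layer,
                          initial_state=[state_h, state_c, state_h, state_c],
                          mask=mentions)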

3 - I assume the mask is the proper way to tell the LSTM which tokens it should pay attention to. Is this right, or should I do it differently?
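
(Toy example of how the mentions mask is meant to line up with the labels: True for tokens inside a disease mention, False elsewhere.)

labels_example = ['0', '0', '0', '0', '0', '0', 'I46.9', '0']
mentions_example = [label != '0' for label in labels_example]
# -> [False, False, False, False, False, False, True, False]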

Please feel free to ask for further details if I have missed something important.
P.S.: I’m using TF datasets which load input_ids, attention_mask, labels and mentions from a Pandas dataframe. I’ve also tried Huggingface datasets followed by the to_tf_dataset method, but the behaviour is the same. The ground-truth outputs (y_test) are one-hot encoded.

test_inputs = {"input_ids": tf.convert_to_tensor(test_input_ids),
               "attention_mask": tf.convert_to_tensor(test_attention_mask),
               "labels": tf.convert_to_tensor(test_labels),
               "mentions": test_mentions}

test_dataset = tf.data.Dataset.from_tensor_slices((test_inputs, tf.convert_to_tensor(y_test)))
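
A sketch of how such predictions can then be decoded back to codes for inspection (code_list here is a hypothetical list mapping one-hot indices to ICD codes, and the batch size is illustrative):

y_pred = model.predict(test_dataset.batch(16))        # shape (n_docs, 256, 750)
decoded_y_pred = [[code_list[idx] for idx in doc.argmax(axis=-1)] for doc in y_pred]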