Decoder Causal Masking [Keras]

It looks like the padding mask issue arises because the English (source) padding mask is being overwritten by the Spanish (target) mask inside the decoder's call method. A solution is to combine the two masks rather than replacing one with the other. I would suggest trying the following modification so that both the English and Spanish padding masks are applied correctly:

def call(self, inputs, encoder_outputs, mask=None):
    causal_mask = self.get_causal_attention_mask(inputs)
    if mask is not None:
        # Expand the (batch, seq_len) padding mask to (batch, 1, seq_len) so it
        # broadcasts against the (batch, seq_len, seq_len) causal mask.
        padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32")
        combined_mask = tf.minimum(padding_mask, causal_mask)
    else:
        padding_mask = None
        combined_mask = causal_mask

    # Causal self-attention over the target (Spanish) sequence: the combined
    # mask hides both future positions and padded positions.
    attention_output_1 = self.self_attention_1(
        query=inputs, value=inputs, key=inputs, attention_mask=combined_mask
    )
    # Cross-attention over the encoder (English) outputs: the padding mask
    # stops the decoder from attending to padded positions. This assumes the
    # source and target are padded to the same sequence length, as in the
    # Keras translation example.
    attention_output_2 = self.self_attention_2(
        query=attention_output_1,
        value=encoder_outputs,
        key=encoder_outputs,
        attention_mask=padding_mask,
    )
    return self.layernorm_3(attention_output_2)
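
In case get_causal_attention_mask is not shown in your class, here is a minimal sketch of one that produces the (batch, seq_len, seq_len) lower-triangular mask the code above expects. It follows the pattern from the Keras English-to-Spanish transformer example, so treat it as an assumption about your layer rather than a drop-in replacement:

def get_causal_attention_mask(self, inputs):
    input_shape = tf.shape(inputs)
    batch_size, sequence_length = input_shape[0], input_shape[1]
    # Position i may attend to positions j <= i (lower-triangular matrix).
    i = tf.range(sequence_length)[:, tf.newaxis]
    j = tf.range(sequence_length)
    mask = tf.cast(i >= j, dtype="int32")
    mask = tf.reshape(mask, (1, sequence_length, sequence_length))
    # Tile the single mask across the batch dimension.
    mult = tf.concat(
        [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
        axis=0,
    )
    return tf.tile(mask, mult)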

This should ensure both the English and Spanish masks are respected without overriding one another.
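
To see what the combination does on a toy example, here is a standalone check (TensorFlow only; the shapes mirror the ones in the call method above, with a made-up length-4 sequence whose last position is padding):

import tensorflow as tf

# (batch, seq) padding mask, expanded to (batch, 1, seq) for broadcasting.
padding_mask = tf.constant([[1, 1, 1, 0]], dtype=tf.int32)[:, tf.newaxis, :]

# Lower-triangular causal mask of shape (batch, seq, seq).
causal_mask = tf.linalg.band_part(tf.ones((1, 4, 4), tf.int32), -1, 0)

combined = tf.minimum(padding_mask, causal_mask)
print(combined[0].numpy())
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 0]]

A position is attendable only where both masks are 1, which is exactly why tf.minimum works here: neither mask can silently override the other.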

Solution provided by Triskel Data Deterministic Ai
