Quick question on attention masking in transformer models

I have been trying to understand whether both a causal mask and an attention mask are required during text-generation inference.

Here is my reasoning.

Consider the forward function in the Llama modeling code: during both training and inference, the _update_causal_mask function is called.

I am using SDPA attention and a DynamicCache, so the _ignore_causal_mask_sdpa function is invoked in turn.
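
For reference, this is roughly my setup (a minimal sketch; the checkpoint name is just a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder; any Llama checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")

inputs = tokenizer("Hello, my name is", return_tensors="pt")
cache = DynamicCache()

with torch.no_grad():
    # Prefill: the full prompt goes through the model once and fills the cache.
    out = model(**inputs, past_key_values=cache, use_cache=True)
    next_token = out.logits[:, -1:].argmax(dim=-1)

    # Decode step: only the new token is fed in; keys/values come from the cache.
    out = model(input_ids=next_token, past_key_values=cache, use_cache=True)
```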

During inference, the condition inside _ignore_causal_mask_sdpa is satisfied for every sequentially generated token, so the causal mask is always None. During training, however, the condition does not hold, so a causal mask does get created.
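
To check my understanding of why dropping the mask is safe in that case, here is a toy example with torch.nn.functional.scaled_dot_product_attention (not the actual Transformers code): with a single query token and a KV cache, the new token is allowed to attend to every cached position, so an all-ones mask and no mask at all give the same result.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy shapes: batch=1, heads=2, head_dim=8.
# Incremental decoding: one new query token, keys/values cover
# 5 cached positions plus the new one.
q = torch.randn(1, 2, 1, 8)
k = torch.randn(1, 2, 6, 8)
v = torch.randn(1, 2, 6, 8)

# Explicit "causal" mask for this step: the new token may attend to everything.
mask = torch.ones(1, 1, 1, 6, dtype=torch.bool)

with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
no_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

print(torch.allclose(with_mask, no_mask))  # True: the mask is a no-op here
```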

And it is this causal mask (or None) that is then passed downstream to the attention layers as the attention mask.

This would mean that, for batched inference, the mask would still be None.
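
For question 2 below, this is the kind of batched call I have in mind (a sketch; the checkpoint and prompts are placeholders). The tokenizer produces an attention_mask with zeros over the left padding, and I would like to understand where that mask ends up in the attention computation.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder; any Llama checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # Llama has no pad token by default
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")

prompts = ["The capital of France is", "Attention masks are used to"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# inputs.attention_mask has 0s over the padded positions, so it is not all ones.
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```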

Questions:

  1. Does this mean the causal mask will always be None during inference?
  2. How are the attention masks used during batched inference?

Any insights into these would be very helpful. Thanks!
