I have been trying to understand whether both a causal mask and an attention mask are required during text generation inference.
Here is my reasoning.
Consider the forward function in the Llama modeling code: during both training and inference, the `_update_causal_mask` function is called. I am using `sdpa` attention with a `DynamicCache`, so the `_ignore_causal_mask_sdpa` function is invoked in turn.
During inference, this condition is satisfied for every sequentially generated token, and hence the causal mask is always `None`. During training, however, it is not, so a causal mask does get created. It is this causal mask that is then passed downstream as the attention mask, which would mean that for batched inference the mask would still be `None` (a rough check I used to inspect this is sketched below).
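For reference, this is roughly how I checked which mask actually reaches the decoder layers during generation. It is just a sketch: the checkpoint name is a placeholder for whichever Llama model you use, and I am assuming the mask arrives at the layer as a keyword argument named `attention_mask` (that detail may vary across transformers versions).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint -- substitute any Llama checkpoint you have access to.
model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")
model.eval()

def report_mask(module, args, kwargs):
    # Assumes the mask is forwarded to the layer as the `attention_mask` kwarg.
    mask = kwargs.get("attention_mask")
    print("attention_mask at layer 0:", None if mask is None else tuple(mask.shape))

# Inspect what the first decoder layer receives on every forward call.
handle = model.model.layers[0].register_forward_pre_hook(report_mask, with_kwargs=True)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    model.generate(**inputs, max_new_tokens=5, do_sample=False)

handle.remove()
```

With a single, unpadded prompt, my expectation from the reasoning above is that this prints `None` at every step.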
Questions:
- Does this mean the causal mask will always be `None` during inference?
- How are the attention masks used during batched inference? (A sketch of what I mean by the batched case is below.)
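To make the second question concrete, here is my mental model of a single batched decoding step, written with plain torch SDPA rather than the transformers code path (so the shapes and the boolean mask convention here are my assumptions, not the library's). The point is that with left padding, the per-sequence padding has to be folded into the mask somewhere, otherwise the pad positions leak into attention:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, kv_len, dim = 2, 4, 4, 8

# One decoding step: a single new query token per sequence,
# attending over kv_len cached key/value positions.
q = torch.randn(batch, heads, 1, dim)
k = torch.randn(batch, heads, kv_len, dim)
v = torch.randn(batch, heads, kv_len, dim)

# No-padding case: every cached position is real, so no mask is needed --
# the single query token may attend to the whole cache.
out_no_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

# Batched case with left padding: sequence 0 starts with two pad positions.
padding_mask = torch.tensor([[0, 0, 1, 1],   # 0 = pad, 1 = real token
                             [1, 1, 1, 1]], dtype=torch.bool)

# Boolean mask broadcast to (batch, heads, q_len, kv_len); True = allowed to attend.
attn_mask = padding_mask[:, None, None, :]
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

print(torch.allclose(out_no_mask[0], out_masked[0]))  # padded sequence: expect False
print(torch.allclose(out_no_mask[1], out_masked[1]))  # unpadded sequence: expect True
```

So my second question is really about where this padding information ends up if the mask passed downstream is `None`.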
Any insights into these would be very helpful. Thanks!