Hi, I have recently been confused about flash_attn_varlen_func.
Suppose I have the attention mask below:
attention_mask = torch.tensor(
    [
        [1, 1, 1, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1],
    ],
    device='cuda:0',
)
According to the Llama model code, the unpadding step produces:
indices = [0, 1, 2, 5, 6, 10, 11, 12, 13, 14]
cu_seqlens = [0, 3, 5, 10]  # this is what gets passed as cu_seqlens_q to flash_attn_varlen_func
max_seqlen_in_batch = 5
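For reference, here is a minimal sketch of that unpadding helper as I understand it from the Hugging Face Llama code (_get_unpad_data); the exact name and location may differ between transformers versions:

import torch
import torch.nn.functional as F

def _get_unpad_data(attention_mask):
    # per-sequence counts of valid tokens: [3, 2, 5] for the mask above
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # flat positions of the valid tokens in the (batch * seq_len) layout
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # cumulative lengths with a leading 0: [0, 3, 5, 10]
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch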
My attention-mask format is a little different:
attention_mask = torch.tensor(
    [
        [0, 2, 3, torch.nan, torch.nan],
        [0, 2, torch.nan, torch.nan, torch.nan],
        [0, 5, torch.nan, torch.nan, torch.nan],
    ],
    device='cuda:0',
)
It means that the first sequence in the batch packs two samples (a and b): 2 tokens for a and 1 token for b, and I do not want b to attend to a when computing attention.
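In other words, each row stores cumulative sample boundaries, so the per-sample lengths are just the successive differences. A tiny illustration of my own (not model code) for the first row:

import torch

row = torch.tensor([0., 2., 3., torch.nan, torch.nan])
boundaries = row[~torch.isnan(row)]   # tensor([0., 2., 3.])
lengths = boundaries.diff()           # tensor([2., 1.]) -> sample a has 2 tokens, sample b has 1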
So I wrote my packing_attention_unpad below:
def packing_attention_unpad(attention_mask: torch.Tensor):
    # turn each row's cumulative sample boundaries into global cumulative lengths
    cu_seqlens_k = attention_mask.clone()
    for i in range(1, cu_seqlens_k.size(0)):
        cu_seqlens_k[i, :] += max(cu_seqlens_k[i - 1])
    # drop the leading 0 of every row, keep the non-nan boundaries, then re-prepend a single 0
    cu_seqlens_k = cu_seqlens_k[..., 1:]
    cu_seqlens_k = cu_seqlens_k[torch.where(~torch.isnan(cu_seqlens_k))].to(dtype=torch.int32)
    cu_seqlens_k = torch.nn.functional.pad(cu_seqlens_k, (1, 0))
    # longest row (in valid tokens) across the batch
    max_seqlen_in_batch_k = attention_mask[~torch.isnan(attention_mask)].max()
    # flat indices of the valid tokens in the (batch * seq_len) layout
    indices_k = []
    seq_length = attention_mask.size(1)
    for i in range(attention_mask.size(0)):
        max_num = max(attention_mask[i, ...])
        indices_k.extend([j + i * seq_length for j in range(int(max_num.item()))])
    indices_k = torch.tensor(indices_k, dtype=torch.int64, device=attention_mask.device)
    return indices_k, cu_seqlens_k, max_seqlen_in_batch_k
Result:
indices = [0, 1, 2, 5, 6, 10, 11, 12, 13, 14]
cu_seqlens = [0, 2, 3, 5, 10]  # this is what gets passed as cu_seqlens_q to flash_attn_varlen_func
max_seqlen_in_batch = 5
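For context, this is roughly how I feed these values into flash_attn_varlen_func (the head sizes and the random q/k/v here are just placeholders of mine; index_first_axis comes from flash_attn.bert_padding):

import torch
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis

indices_k, cu_seqlens_k, max_seqlen_in_batch_k = packing_attention_unpad(attention_mask)

batch, seq_len, n_heads, head_dim = 3, 5, 8, 64
q = torch.randn(batch, seq_len, n_heads, head_dim, device='cuda:0', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# gather only the valid tokens into packed (total_tokens, n_heads, head_dim) tensors
q_unpad = index_first_axis(q.reshape(-1, n_heads, head_dim), indices_k)
k_unpad = index_first_axis(k.reshape(-1, n_heads, head_dim), indices_k)
v_unpad = index_first_axis(v.reshape(-1, n_heads, head_dim), indices_k)

out = flash_attn_varlen_func(
    q_unpad, k_unpad, v_unpad,
    cu_seqlens_q=cu_seqlens_k,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=int(max_seqlen_in_batch_k),
    max_seqlen_k=int(max_seqlen_in_batch_k),
    causal=True,
)  # out: (total_tokens, n_heads, head_dim)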
- Am I doing this right? I tested it with Llama: when batch_size=1 the model works fine, but when batch_size>1, NaN loss and NaN logits start appearing at step 2… Why does that happen?
- According to the documentation, cu_seqlens_q should have shape (batch_size + 1,), but my cu_seqlens has shape (batch_size + 2,). It did not give me an error, so how does that work?
Thanks in advance.