Transformers llama flash_attn_varlen_func questions

Hi, lately I have been confused about flash_attn_varlen_func.

Suppose I have the attention mask below:

attention_mask = torch.tensor(
    [
        [1,1,1,0,0],
        [1,1,0,0,0],
        [1,1,1,1,1],
    ],device='cuda:0'
)

According to the Llama model's unpadding logic, this gives:

indices=[ 0,  1,  2,  5,  6, 10, 11, 12, 13, 14]
cu_seqlens=[ 0,  3,  5, 10] # passed to flash_attn_varlen_func as cu_seqlens_q
max_seqlen_in_batch=5
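For reference, the Llama values above can be reproduced with a short sketch. As far as I know this mirrors the `_get_unpad_data` helper in transformers, but the function name `get_unpad_data_sketch` is hypothetical and the sketch runs on CPU instead of cuda:0:

```python
import torch
import torch.nn.functional as F

def get_unpad_data_sketch(attention_mask: torch.Tensor):
    # Number of real (mask == 1) tokens in each row.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Positions of the real tokens in the flattened (batch * seq_len) tensor.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # Longest row, as a Python int.
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Cumulative row lengths with a leading 0: one entry per row, plus the start.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 1, 1]])
indices, cu_seqlens, max_seqlen = get_unpad_data_sketch(attention_mask)
print(indices.tolist())     # [0, 1, 2, 5, 6, 10, 11, 12, 13, 14]
print(cu_seqlens.tolist())  # [0, 3, 5, 10]
print(max_seqlen)           # 5
```

Here each row of the padded batch becomes exactly one sequence, so cu_seqlens has batch_size + 1 entries.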

My implementation of the attention mask is a little different:

attention_mask = torch.tensor(
    [
        [0,2,3,torch.nan,torch.nan],
        [0,2,torch.nan,torch.nan,torch.nan],
        [0,5,torch.nan,torch.nan,torch.nan],
    ],device='cuda:0'
)

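Each row lists cumulative sample boundaries starting at 0, padded with NaN. For example, the first row [0,2,3] means that the first sequence in the batch packs 2 samples (a, b): 2 tokens for a and 1 token for b. I do not want b to attend to a when calculating attention.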
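To make the boundary format concrete, here is a small plain-Python sketch (the helper `segment_lengths` is hypothetical) that turns each row into per-sample lengths. It assumes every row starts with 0 and lists increasing cumulative boundaries:

```python
import math

def segment_lengths(row):
    # row: one row of the boundary-format mask, e.g. [0, 2, 3, nan, nan].
    # Drop the NaN padding and keep the real boundaries.
    bounds = [int(x) for x in row if not math.isnan(x)]
    # Consecutive differences give the length of each packed sample.
    return [end - start for start, end in zip(bounds[:-1], bounds[1:])]

nan = float("nan")
mask_rows = [
    [0, 2, 3, nan, nan],
    [0, 2, nan, nan, nan],
    [0, 5, nan, nan, nan],
]
lengths = [segment_lengths(r) for r in mask_rows]
print(lengths)  # [[2, 1], [2], [5]]
```

So this batch contains 3 rows but 4 samples.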

So I wrote the following packing_attention_unpad:

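import torch

def packing_attention_unpad(attention_mask: torch.Tensor):
    # attention_mask: (batch, seq_len) float tensor of cumulative sample
    # boundaries per row, padded with NaN.
    cu_seqlens_k = attention_mask.clone()
    # Shift each row's boundaries by the last boundary of the previous row so
    # they become offsets into the flattened token dimension.
    # (Builtin max() works here because each row starts with a real value and
    # comparisons against NaN are False.)
    for i in range(1, cu_seqlens_k.size(0)):
        cu_seqlens_k[i, :] += max(cu_seqlens_k[i - 1])
    # Drop each row's leading 0, keep the non-NaN boundaries, prepend a single 0.
    cu_seqlens_k = cu_seqlens_k[..., 1:]
    cu_seqlens_k = cu_seqlens_k[torch.where(~torch.isnan(cu_seqlens_k))].to(dtype=torch.int32)
    cu_seqlens_k = torch.nn.functional.pad(cu_seqlens_k, (1, 0))

    # Largest boundary value in the mask (an upper bound on any sample's length).
    max_seqlen_in_batch_k = attention_mask[~torch.isnan(attention_mask)].max()

    # Flattened indices of the first max(row) tokens in every row.
    indices_k = []
    seq_length = attention_mask.size(1)
    for i in range(attention_mask.size(0)):
        max_num = max(attention_mask[i, ...])
        indices_k.extend([j + i * seq_length for j in range(int(max_num.item()))])
    indices_k = torch.tensor(indices_k, dtype=torch.int64, device=attention_mask.device)

    return indices_k, cu_seqlens_k, max_seqlen_in_batch_k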

Result:

indices=[ 0,  1,  2,  5,  6, 10, 11, 12, 13, 14]
cu_seqlens=[ 0, 2, 3,  5, 10] # passed to flash_attn_varlen_func as cu_seqlens_q
max_seqlen_in_batch=5
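For comparison, a loop-free sketch of the same conversion might look like this. `packing_attention_unpad_vectorized` is a hypothetical name, and it assumes the same mask format (each row starts at 0, boundaries increase, NaN padding at the end, real tokens left-aligned). One difference from the function above: it computes max_seqlen as the longest single packed sample (the largest difference between consecutive cu_seqlens entries) and returns a Python int, whereas the function above returns the largest boundary value as a tensor; both give 5 for this example.

```python
import torch
import torch.nn.functional as F

def packing_attention_unpad_vectorized(attention_mask: torch.Tensor):
    # attention_mask: (batch, seq_len) float tensor; each row holds cumulative
    # sample boundaries starting at 0, padded with NaN at the end.
    valid = ~torch.isnan(attention_mask)
    # Tokens used by each row = its largest (last) boundary.
    row_lens = torch.where(valid, attention_mask, torch.zeros_like(attention_mask))
    row_lens = row_lens.amax(dim=1).to(torch.int64)
    # Start offset of each row once all rows are concatenated.
    row_offsets = F.pad(torch.cumsum(row_lens, dim=0), (1, 0))[:-1]
    # Shift every boundary (except each row's leading 0) by its row offset,
    # then keep only the real (non-NaN) boundaries in row-major order.
    shifted = attention_mask[:, 1:] + row_offsets[:, None].to(attention_mask.dtype)
    ends = shifted[valid[:, 1:]].to(torch.int32)
    cu_seqlens = F.pad(ends, (1, 0))
    # Longest single packed sample, as a Python int.
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
    # Flattened indices of the real tokens (assumed left-aligned in each row).
    seq_len = attention_mask.size(1)
    positions = torch.arange(seq_len, device=attention_mask.device)
    token_mask = positions[None, :] < row_lens[:, None]
    indices = torch.nonzero(token_mask.flatten(), as_tuple=False).flatten()
    return indices, cu_seqlens, max_seqlen

nan = float("nan")
attention_mask = torch.tensor(
    [
        [0, 2, 3, nan, nan],
        [0, 2, nan, nan, nan],
        [0, 5, nan, nan, nan],
    ]
)
indices, cu_seqlens, max_seqlen = packing_attention_unpad_vectorized(attention_mask)
print(indices.tolist())     # [0, 1, 2, 5, 6, 10, 11, 12, 13, 14]
print(cu_seqlens.tolist())  # [0, 2, 3, 5, 10]
print(max_seqlen)           # 5
```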
  1. Am I doing this right? I tested it with Llama: with batch_size=1 the model works fine, but with batch_size>1 I get NaN loss and NaN logits starting at step 2. Why does that happen?
  2. According to the documentation of cu_seqlens_q, it should have shape (batch_size + 1,), but my cu_seqlens has shape (batch_size + 2,) in this example. It did not raise an error, so how does this work?
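For reference, the flash-attn docstring describes cu_seqlens_q as the cumulative sequence lengths of the sequences in the batch, used to index into the flattened q. A plain-PyTorch sketch of that behavior (hypothetical `varlen_attention_reference`, assuming causal=True as used for Llama, and ignoring dropout, custom softmax scaling and grouped-query attention) treats every pair of consecutive cu_seqlens entries as one independent sequence:

```python
import torch
import torch.nn.functional as F

def varlen_attention_reference(q, k, v, cu_seqlens):
    # q, k, v: (total_tokens, num_heads, head_dim), all sequences concatenated
    # along the token dimension (the unpadded layout).
    # cu_seqlens: 1-D int tensor of length num_sequences + 1.
    out = torch.empty_like(q)
    bounds = cu_seqlens.tolist()
    for start, end in zip(bounds[:-1], bounds[1:]):
        # SDPA expects (..., seq_len, head_dim), so move heads in front.
        qs, ks, vs = (t[start:end].transpose(0, 1) for t in (q, k, v))
        # Causal attention restricted to this one sequence: tokens never see
        # tokens from a neighboring sequence.
        o = F.scaled_dot_product_attention(qs, ks, vs, is_causal=True)
        out[start:end] = o.transpose(0, 1)
    return out

torch.manual_seed(0)
cu_seqlens = torch.tensor([0, 2, 3, 5, 10], dtype=torch.int32)
total_tokens = int(cu_seqlens[-1])
q, k, v = (torch.randn(total_tokens, 2, 4) for _ in range(3))
out = varlen_attention_reference(q, k, v, cu_seqlens)
print(len(cu_seqlens) - 1)  # 4 sequences
```

In this sketch the number of sequences is simply len(cu_seqlens) - 1, independent of how many rows the padded batch had.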

Thanks in advance.