# get input sizes with logs
L_K = key_states.size(1)
L_Q = query_states.size(1)
log_L_K = np.ceil(np.log1p(L_K)).astype("int").item()
log_L_Q = np.ceil(np.log1p(L_Q)).astype("int").item()
# calculate a subset of samples to slice from K and create Q_K_sample
U_part = min(sampling_factor * L_Q * log_L_K, L_K)
# create Q_K_sample (the q_i * k_j^T term in the sparsity measurement)
index_sample = torch.randint(0, L_K, (U_part,))
K_sample = key_states[:, index_sample, :]
Q_K_sample = torch.bmm(query_states, K_sample.transpose(1, 2))
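For reference, here is a self-contained version of that snippet with made-up shapes (the batch size, sequence lengths, head dimension, and sampling_factor below are illustrative assumptions, not values from the blogpost):

import numpy as np
import torch

batch_size, L_Q, L_K, d_head = 2, 4, 256, 16  # assumed toy shapes
sampling_factor = 1                           # assumed for illustration
query_states = torch.randn(batch_size, L_Q, d_head)
key_states = torch.randn(batch_size, L_K, d_head)

log_L_K = np.ceil(np.log1p(L_K)).astype("int").item()  # ceil(ln(257)) = 6
U_part = min(sampling_factor * L_Q * log_L_K, L_K)     # min(4 * 6, 256) = 24
index_sample = torch.randint(0, L_K, (U_part,))        # 24 random key indices
K_sample = key_states[:, index_sample, :]              # shape (2, 24, 16)
Q_K_sample = torch.bmm(query_states, K_sample.transpose(1, 2))
print(Q_K_sample.shape)  # torch.Size([2, 4, 24]); only 24 of the 256 keys are scored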
Shouldn’t U_part be
U_part = min(sampling_factor * log_L_K, L_K)
without the L_Q? With the L_Q factor in there, L_Q has to be significantly smaller than L_K for U_part to come out smaller than L_K, and we want U_part < L_K: when it equals L_K, Q_K_sample is just the full Q @ K, and avoiding that computation is the whole point of the sampling.
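To make the concern concrete with illustrative numbers (sampling_factor = 5 and L_Q = L_K = 64, neither taken from the post):

sampling_factor, L_Q, L_K = 5, 64, 64
log_L_K = 5  # ceil(ln(1 + 64)) = ceil(4.17) = 5
min(sampling_factor * L_Q * log_L_K, L_K)  # -> min(1600, 64) = 64, the full L_K
min(sampling_factor * log_L_K, L_K)        # -> min(25, 64) = 25, a true subsample

So with the L_Q factor, the min saturates at L_K whenever L_Q is on the order of L_K, and the “sample” covers every key.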
Thank you for reading our blogpost so thoroughly, and for pointing out this issue!
Actually, at the beginning our implementation was indeed as you describe, i.e.:
U_part = min(sampling_factor * log_L_K, L_K)
But after looking more thoroughly at the original implementation, we found that U_part is set to c * m * ln(n) (with m = L_Q and n = L_K) for stability reasons, i.e.: