Probsparse_attention in Informer

Hi,
I was reading the blog post Multivariate Probabilistic Time Series Forecasting with Informer, which has a section where probsparse_attention is implemented, and I think there is a bug in it:

    # get input sizes with logs
    L_K = key_states.size(1)
    L_Q = query_states.size(1)
    log_L_K = np.ceil(np.log1p(L_K)).astype("int").item()
    log_L_Q = np.ceil(np.log1p(L_Q)).astype("int").item()

    # calculate a subset of samples to slice from K and create Q_K_sample
    U_part = min(sampling_factor * L_Q * log_L_K, L_K)

    # create Q_K_sample (the q_i * k_j^T term in the sparsity measurement)
    index_sample = torch.randint(0, L_K, (U_part,))
    K_sample = key_states[:, index_sample, :]
    Q_K_sample = torch.bmm(query_states, K_sample.transpose(1, 2))

shouldn’t U_part be

    U_part = min(sampling_factor * log_L_K, L_K)

without the L_Q? With the L_Q factor, U_part only stays below L_K when L_Q is much smaller than L_K, and we do want U_part < L_K: once it equals L_K, Q_K_sample is just the full Q @ K^T product, and avoiding that full product is the whole point of the sampling.
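
To make this concrete, here is a quick check with made-up numbers (a minimal sketch; the length 96 and sampling_factor 5 are just illustrative values I picked):

    import numpy as np

    # self-attention setting: query and key sequences have the same length
    sampling_factor = 5  # the paper's hyperparameter c
    L_Q = L_K = 96

    log_L_K = np.ceil(np.log1p(L_K)).astype("int").item()  # ceil(ln(97)) = 5

    # blog version: saturates at L_K, so nothing is actually subsampled
    U_part = min(sampling_factor * L_Q * log_L_K, L_K)
    print(U_part)      # 96 == L_K

    # version without L_Q: a genuine subsample of the keys
    U_part_alt = min(sampling_factor * log_L_K, L_K)
    print(U_part_alt)  # 25 < L_K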

Hi,

I’ve pinged @kashif and Eli to take a look!

Indeed, it’s a bug in the pseudo-code of the blog… If you want, you can open a PR to fix the blog post, or I will do it by tomorrow.

Thanks again!

Hi Gozdi,

Thank you for reading our blog post so thoroughly, and for pointing out this issue!
Actually, at the beginning our implementation was indeed as you suggest, i.e.:

    U_part = min(sampling_factor * log_L_K, L_K)

But after looking more thoroughly at the original implementation, we saw that U_part is set to c*m*ln(n) for stability reasons, i.e.:

    U_part = min(sampling_factor * L_Q * log_L_K, L_K)

Citing the authors:

"we set U_part to c*m*ln(n) to make the value of U_part not too small, so as to ensure the stability and better effect of the model."

Reference: About the U_part · Issue #161 · zhouhaoyi/Informer2020 · GitHub
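
For intuition, here is a quick comparison of the two formulas (a minimal sketch; the sequence length 1024 is just an illustrative value):

    import numpy as np

    sampling_factor = 5  # the paper's hyperparameter c
    L_Q = L_K = 1024     # illustrative sequence length

    log_L_K = np.ceil(np.log1p(L_K)).astype("int").item()  # ceil(ln(1025)) = 7

    # c * ln(n): only 35 of the 1024 keys would back each query's score
    print(min(sampling_factor * log_L_K, L_K))        # 35

    # c * m * ln(n), clipped at L_K: the measurement uses all keys here
    print(min(sampling_factor * L_Q * log_L_K, L_K))  # 1024

Since c*ln(n) grows only logarithmically, the sampled keys become a vanishing fraction of L_K as sequences get longer, whereas c*m*ln(n), clipped at L_K, keeps the sparsity measurement well supported.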

Finally, you can see that our implementation was originally as you described, and was then fixed. 🙂

Thank you for noticing this! I have added this discussion to the blog post for future readers.

Eli