xLSTM mode="train": shape mismatch / runtime errors with chunk_size on short sequences

Greetings,

I’m trying to use xLSTM in “train” mode because I need to train on truncated sequences (VRAM constraints). Inference works, but training consistently crashes with shape errors (or sometimes TypeError: ‘Tensor’ object is not callable). I can’t find a configuration where the shapes line up.

This is my minimal test code:

from transformers import xLSTMConfig, xLSTMModel
import torch

D = 128      # embedding_dim = hidden_size
S = 6        # Sequence length
H = 8        # num_heads

cfg = xLSTMConfig(
    embedding_dim=D,
    hidden_size=D,
    num_hidden_layers=3,
    num_heads=H,
    # --- Training ---
    mode="train",           # "train" or "train_with_padding" ?
    use_cache=False,
    return_last_states=False,
    # --- Chunking ---
    chunk_size=3,           # 6 % 3 == 0
    # --- Head-Dims ---
    qk_dim_factor=0.5,      # dqk = D * 0.5 / H = 8
    v_dim_factor=1.0,       # dv  = D * 1.0 / H = 16
    # --- Kernels ---
    # chunkwise_kernel="chunkwise--native_autograd",
    # sequence_kernel="native_sequence__native",
)

model = xLSTMModel(cfg).cuda()
x = torch.randn(1, S, D, device="cuda")
out = model(inputs_embeds=x, use_cache=False)
print(out.last_hidden_state.shape)

This is throwing the error:

RuntimeError: shape '[1, 8, 2, 3, 16]' is invalid for input of size 384

This looks like the model trying to view a tensor as (batch=1, heads=8, num_chunks=2, chunk_size=3, dqk=16).
That target shape totals 1 * 8 * 2 * 3 * 16 = 768 elements, while the source tensor only has 384 elements (which equals 3 * 128). My two observations, with a short arithmetic sketch after them:

With D=128 and qk_dim_factor=0.5, I’d expect per‑head dqk = 8, not 16. So I may be misunderstanding how qk_dim_factor is applied in train mode, or the code is using a different value internally.

The mismatch 768 vs. 384 persists even if I change qk_dim_factor, v_dim_factor, or chunk_size to other sensible values.
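
To make the mismatch concrete, here is the arithmetic as I understand it (plain Python; all variable names below are mine, not from the xLSTM code):

# Back-of-the-envelope check of the failing view (my own arithmetic).
B, S, D, H = 1, 6, 128, 8
chunk_size = 3
num_chunks = S // chunk_size                          # 2

qk_dim_factor, v_dim_factor = 0.5, 1.0
dqk = int(D * qk_dim_factor) // H                     # 8  -> per-head q/k dim I expect
dv = int(D * v_dim_factor) // H                       # 16 -> per-head v dim (= D // H)

target_elems = B * H * num_chunks * chunk_size * 16   # 768, target shape from the error
source_elems = 3 * 128                                # 384, source size from the error
print(dqk, dv, target_elems, source_elems)            # 8 16 768 384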

What I tried:

mode="train" with chunk_size ∈ {7, 6, 3, 2} for S = 6 (a small reproduction sweep is sketched after this list). Results:

chunk_size = 7 (for S=6) → ValueError: Sequence length 6 is not divisible by chunk size 7.

chunk_size = 6 → sometimes TypeError: ‘Tensor’ object is not callable inside a vecM_* combine path.

chunk_size = 3 or 2 → shape/view errors like the one above.

Switched to mode="train_with_padding" with similar settings (no cache) → still seeing shape/view errors for short sequences.

Ensured use_cache=False during training and did not pass any cache (I only use cache for inference/streaming).

Verified basic divisibility constraints (e.g., hidden_size % num_heads == 0).

Tried different qk_dim_factor/v_dim_factor combinations (e.g., 1.0 / 1.0, 0.5 / 0.5, etc.). The failing target shape often remains [1, 8, 2, 3, 16] vs. a source size of 384.
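
For completeness, this is roughly how I swept the chunk sizes (the loop below is just my own test harness around the config from the first snippet):

# Sweep chunk_size and record what each value raises (my own test loop).
import torch
from transformers import xLSTMConfig, xLSTMModel

D, S, H = 128, 6, 8
for cs in (7, 6, 3, 2):
    cfg = xLSTMConfig(
        embedding_dim=D, hidden_size=D, num_hidden_layers=3, num_heads=H,
        mode="train", use_cache=False, return_last_states=False,
        chunk_size=cs, qk_dim_factor=0.5, v_dim_factor=1.0,
    )
    try:
        model = xLSTMModel(cfg).cuda()
        x = torch.randn(1, S, D, device="cuda")
        out = model(inputs_embeds=x, use_cache=False)
        print(cs, "OK", tuple(out.last_hidden_state.shape))
    except Exception as e:
        print(cs, type(e).__name__, e)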

Expected vs. actual

Expected: In mode="train" with use_cache=False, the call returns last_hidden_state of shape (1, 6, 128) without shape/view errors, similar to inference, where setting chunk_size >= S (here S = 6) runs fine.

Actual: Shape/view errors as above, or occasionally TypeError: ‘Tensor’ object is not callable from internal combine functions when chunk_size equals S.

My Questions

Exact constraints for train mode?
Besides sequence_length % chunk_size == 0, are there additional constraints between hidden_size, num_heads, qk_dim_factor, and v_dim_factor in the training kernels? For example, should
(hidden_size * qk_dim_factor) % num_heads == 0 and (hidden_size * v_dim_factor) % num_heads == 0
always hold (interpreting factors as totals, not per‑head)?
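
In other words, is a config expected to be valid whenever checks like the ones below pass? (These assertions are my guess at the constraints, not something I found documented.)

# My guess at the divisibility constraints - please correct me if the
# training kernels require more than this.
def looks_valid(hidden_size, num_heads, seq_len, chunk_size, qk_dim_factor, v_dim_factor):
    assert hidden_size % num_heads == 0
    assert seq_len % chunk_size == 0
    assert (hidden_size * qk_dim_factor) % num_heads == 0   # factor read as a total
    assert (hidden_size * v_dim_factor) % num_heads == 0
    return True

print(looks_valid(hidden_size=128, num_heads=8, seq_len=6, chunk_size=3,
                  qk_dim_factor=0.5, v_dim_factor=1.0))     # True, yet the model crashes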

Interpretation of qk_dim_factor:
In my example, I expected per‑head dqk = (hidden_size * qk_dim_factor) / num_heads = 8, but the failing view suggests dqk = 16. Is qk_dim_factor applied differently in train mode or in a specific kernel path?

Known issues with short sequences / small chunk_size?
Are the shape/view/TypeError symptoms known for short sequences (e.g., S=6) in mode=“train”?

Recommended configuration for training on truncated sequences (no cache):
Is mode=“train_with_padding” the intended path, and if so, should I provide an attention_mask and mask out the padding in the loss, or is the padding handled fully internally?
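
If manual masking is the intended path, I would do something along these lines (the LM head and loss handling below are my own sketch on top of last_hidden_state, not part of the library; I only want to confirm this is the expected pattern):

# Masking padded positions out of the loss myself (my own sketch, assuming a
# padded batch and a 0/1 mask; nothing here comes from the xLSTM code).
import torch
import torch.nn.functional as F

B, S, D, vocab = 1, 6, 128, 1000
hidden = torch.randn(B, S, D)                  # stand-in for out.last_hidden_state
labels = torch.randint(0, vocab, (B, S))
pad_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])  # 1 = real token, 0 = padding

lm_head = torch.nn.Linear(D, vocab)            # my own head, not from xLSTM
logits = lm_head(hidden)

per_token = F.cross_entropy(
    logits.view(-1, vocab), labels.view(-1), reduction="none"
).view(B, S)
loss = (per_token * pad_mask).sum() / pad_mask.sum()   # padding excluded from the loss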

Any guidance on the correct xLSTMConfig for “train” mode (for short sequences) or pointers to a working example would be greatly appreciated. Thanks!


The conditions for it to work are extremely strict, but after downgrading to PyTorch 2.5.1, I barely managed to get it working on Colab Free…

Thank you for your reply.
Unfortunately, I was unable to get the example to run on my Windows computer.
I will try testing it later with Triton on a Linux computer.
Thank you very much for the example.
