flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla

flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla(query: Tensor, kv_cache: Tensor, workspace_buffer: Tensor, qk_nope_head_dim: int, kv_lora_rank: int, qk_rope_head_dim: int, block_tables: Tensor, seq_lens: Tensor | None, max_seq_len: int, sparse_mla_top_k: int = 0, out: Tensor | None = None, bmm1_scale: float | Tensor = 1.0, bmm2_scale: float | Tensor = 1.0, sinks: List[Tensor] | None = None, skip_softmax_threshold_scale_factor: float | None = None, enable_pdl: bool | None = None, backend: str = 'auto', is_var_seq: bool = True, uses_shared_paged_kv_idx: bool = True, lse: Tensor | None = None, return_lse: bool = False, cute_dsl_impl: str = 'auto', kv_scale_format: str = 'auto', cum_seq_lens_q: Tensor | None = None, max_q_len: int | None = None) → Tensor | Tuple[Tensor, Tensor]

Decode MLA with TRTLLM-GEN, CuteDSL, XQA, or SM120/SM121 sparse kernels.

With backend="auto", SM100/SM103 devices use TRTLLM-GEN for sparse MLA when sparse_mla_top_k > 0. SM120/SM121 devices use the packed sparse backend for sparse_mla_top_k > 0 and XQA for dense decode.
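The selection rules above can be restated as a small lookup. This is only an illustrative sketch of the documented behavior, not FlashInfer's dispatch code: the function name and the integer `sm` argument are hypothetical, and the dense SM100/SM103 case lists the two candidates that this page describes as being profiled under autotune.

```python
def auto_backend_candidates(sm: int, sparse_mla_top_k: int = 0) -> list[str]:
    """Illustrative mirror of the backend="auto" rules described in this page.

    `sm` is a hypothetical integer encoding of the compute capability
    (100, 103, 120, 121); it is not a FlashInfer argument.
    """
    sparse = sparse_mla_top_k > 0
    if sm in (100, 103):
        # Sparse MLA goes to TRTLLM-GEN. For dense decode, this page describes
        # autotune profiling both trtllm-gen and cute-dsl.
        return ["trtllm-gen"] if sparse else ["trtllm-gen", "cute-dsl"]
    if sm in (120, 121):
        # Packed sparse backend for sparse MLA, XQA for dense decode.
        return ["sparse"] if sparse else ["xqa"]
    # Other architectures are not covered by the description above.
    return []
```

For example, `auto_backend_candidates(120, sparse_mla_top_k=64)` yields `["sparse"]`, while `auto_backend_candidates(121)` yields `["xqa"]`.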

Parameters:
  • query (torch.Tensor) – Query tensor with shape [batch_size, q_len_per_request, num_heads, head_dim_qk] where head_dim_qk = kv_lora_rank + qk_rope_head_dim. For the SM120/SM121 v32/GLM sparse backend, this must be BF16 with head_dim_qk == 576.

  • kv_cache (torch.Tensor) – For TRTLLM-GEN, CuteDSL, and XQA, the paged KV cache is [num_pages, page_size, kv_lora_rank + qk_rope_head_dim] or [num_pages, 1, page_size, kv_lora_rank + qk_rope_head_dim] and uses the query-compatible dense dtype. For the SM120/SM121 v32/GLM sparse backend, this is a packed uint8 cache with 656 bytes per token, shaped [num_pages, page_size, 656] or [num_pages, 1, page_size, 656].

  • workspace_buffer (torch.Tensor) – Pre-allocated workspace buffer. Must be zero-initialized on first use by kernels that use semaphore state.

  • qk_nope_head_dim (int) – Non-RoPE query dimension. Dense MLA paths commonly use 128 or 64 depending on model. The SM120/SM121 sparse v32/GLM backend ignores this value and validates query.shape[-1] == 576 instead.

  • kv_lora_rank (int) – Latent KV rank. TRTLLM-GEN and SM120/SM121 sparse v32/GLM use 512.

  • qk_rope_head_dim (int) – RoPE head dimension. Sparse MLA paths use 64.

  • block_tables (torch.Tensor) – Page table for dense MLA backends when sparse_mla_top_k == 0. For SM100/SM103 TRTLLM-GEN sparse MLA it is the usual paged block table. When cum_seq_lens_q is provided with sparse MLA, pass compact sparse rows in flattened query-token order with shape [total_q, sparse_mla_top_k]. For SM120/SM121 sparse v32/GLM, it is the sparse index matrix and must have shape [batch_size, q_len_per_request, sparse_mla_top_k] with int32 physical token indices.

  • seq_lens (Optional[torch.Tensor]) – Per-request KV sequence lengths for dense and TRTLLM-GEN paths. For SM120/SM121 sparse v32/GLM, pass [batch_size, q_len_per_request] or flattened [batch_size * q_len_per_request] active top-k lengths; if None, every column in block_tables is active.

  • max_seq_len (int) – Maximum KV sequence length used for dense/TRTLLM-GEN scheduling. Ignored by the SM120/SM121 sparse v32/GLM backend.

  • sparse_mla_top_k (int) – Enables sparse MLA when greater than zero. On SM100/SM103 this selects the TRTLLM-GEN sparse page-table path. On SM120/SM121 with backend="auto" or backend="sparse", this is the width of the packed v32/GLM sparse index matrix. The TRTLLM-GEN backend supports dense query input or flattened query input plus cum_seq_lens_q.

  • out (Optional[torch.Tensor]) – Output tensor. If not provided, it is allocated internally.

  • bmm1_scale (Union[float, torch.Tensor]) – Fused scale for MLA BMM1. TRTLLM-GEN accepts a FP32 tensor or float. CuteDSL, XQA, and SM120/SM121 sparse v32/GLM require a float.

  • bmm2_scale (Union[float, torch.Tensor]) – Fused scale for MLA BMM2. TRTLLM-GEN accepts a FP32 tensor or float. CuteDSL and XQA require a float. SM120/SM121 sparse v32/GLM requires 1.0.

  • sinks (Optional[List[torch.Tensor]]) – Additional value per head in the denominator of the softmax. Supported by trtllm-gen, cute-dsl, and sparse. On cute-dsl this requires the modular implementation; cute_dsl_impl="auto" (the default) promotes to modular automatically, and cute_dsl_impl="monolithic" with sinks set raises ValueError.

  • skip_softmax_threshold_scale_factor (Optional[float]) – Threshold scale factor for skipping softmax operations. Providing a value enables skip-softmax sparsity as described in the paper at https://arxiv.org/abs/2512.12087 (if no value is provided, standard attention is used). A higher threshold generally increases kernel performance at the cost of accuracy. The actual threshold equals the provided scale factor divided by the context length.

  • enable_pdl (Optional[bool]) – Programmatic Dependent Launch toggle. When None (default), auto-detects support from the query device. Honored by the trtllm-gen and xqa backends; ignored by cute-dsl.

  • backend (str = "auto") – Implementation backend. Valid values are "auto", "xqa", "trtllm-gen", "cute-dsl", and "sparse". "auto" chooses "trtllm-gen" for SM100/SM103 sparse MLA and chooses "sparse" for SM120/SM121 when sparse_mla_top_k > 0; otherwise SM120/SM121 dense decode uses "xqa". The cute-dsl backend has two interchangeable implementations (monolithic and modular) on the same shape/dtype envelope; which one runs is controlled by the cute_dsl_impl kwarg below.

  • is_var_seq (bool) – Whether sequence lengths vary across requests. If True, each request may have a different sequence length. Otherwise, the sequence length is the same for all requests in the batch.

  • uses_shared_paged_kv_idx (bool = True) – Whether K and V page indices are shared as a unified index. True (default) uses vLLM/FlashInfer layout with a 2D page table. False uses TRT-LLM layout with a 3D page table [batch_size, 2, max_num_pages_per_seq]. False is only supported by TRTLLM-GEN.

  • lse (Optional[torch.Tensor] = None) –

    Optional pre-allocated buffer for Log-Sum-Exp values. Supported by trtllm-gen, cute-dsl, and sparse backends. Must have dtype torch.float32. Accepted shapes:

    • [batch_size * q_len_per_request, num_qo_heads] (TRTLLM-GEN native; accepted by sparse), or

    • [batch_size, q_len_per_request, num_qo_heads] (cute-dsl native).

    If return_lse is True and this is None, a buffer will be allocated by the backend.

  • return_lse (bool = False) – Whether to return LSE values. Supported by trtllm-gen, cute-dsl, and sparse backends. When True, the function returns (out, lse).

  • cute_dsl_impl (str = "auto") –

    Which cute-dsl implementation to use. Honored when backend="cute-dsl" and when backend="auto" considers the cute-dsl candidate; ignored for non-cute-dsl backends.

    • "auto" (default) — picks monolithic by default, automatically promoted to modular when the call uses a feature monolithic doesn’t support (currently sinks).

    • "modular" — strict. Always run the modular kernels.

    • "monolithic" — strict. Always run the monolithic kernels; raise ValueError if the call uses any modular-only feature (e.g. sinks).

  • kv_scale_format (str = "auto") – Scale semantics for the SM120/SM121 packed v32/GLM sparse backend. "auto" and "pow2_fp32" select DSv3.2 power-of-2 FP32 inline scales; "arbitrary_fp32" selects GLM-style arbitrary FP32 inline scales. Ignored by the trtllm-gen, xqa, and cute-dsl backends.

  • cum_seq_lens_q (Optional[torch.Tensor] = None) – Cumulative query sequence lengths for variable-length query support, shape [batch_size + 1], dtype torch.int32. Must be a 1D tensor with at least two entries. When max_q_len is not provided, this function validates that it starts with 0, ends at query.size(0), and is monotonically non-decreasing. Only supported by the trtllm-gen backend. When provided, query must have shape [total_q, num_heads, head_dim_qk]. For best performance, provide max_q_len together with cum_seq_lens_q to avoid host-side metadata validation.

  • max_q_len (Optional[int] = None) – Maximum query sequence length across all requests when using cum_seq_lens_q. Provide with cum_seq_lens_q to avoid host-side metadata validation. Must be greater than or equal to the maximum segment length represented by cum_seq_lens_q. Over-estimation is safe but may waste work; under-estimation is invalid and may produce incorrect output.
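
The cum_seq_lens_q and max_q_len descriptions above can be made concrete with a small host-side sketch. It assumes per-request query lengths are available as a Python list; both helper names are hypothetical, and the checks mirror the validation rules stated for cum_seq_lens_q (starts at 0, ends at the total query count, non-decreasing). In practice the resulting list would be converted to a torch.int32 tensor before the call.

```python
from itertools import accumulate


def build_query_metadata(q_lens):
    """Return (cum_seq_lens_q, max_q_len) for per-request query lengths.

    Hypothetical helper: `q_lens` holds one non-negative int per request.
    """
    if not q_lens or any(n < 0 for n in q_lens):
        raise ValueError("q_lens must be non-empty and non-negative")
    cum = [0, *accumulate(q_lens)]  # length is batch_size + 1
    return cum, max(q_lens)


def check_cum_seq_lens_q(cum, total_q):
    """Mirror the documented checks: >= 2 entries, 0-start, total_q-end, monotone."""
    return (
        len(cum) >= 2
        and cum[0] == 0
        and cum[-1] == total_q
        and all(a <= b for a, b in zip(cum, cum[1:]))
    )


cum_seq_lens_q, max_q_len = build_query_metadata([1, 3, 2])
```

With query lengths 1, 3, and 2, the flattened query has 6 rows, so query would be shaped [6, num_heads, head_dim_qk]. Passing the exact maximum (3 here) avoids both the host-side validation and the wasted work of an over-estimate.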

Note

In MLA, the BMM1 and BMM2 scales that are actually applied are fused as:

    bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
    bmm2_scale = v_scale * o_scale

or, as tensors:

    bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)])
    bmm2_scale = torch.Tensor([v_scale * o_scale])

The two scale factors should be static constants for CUDA graph capture. Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.

For static constant scale factors, provide the scale factors as floats.
  • (bmm1_scale, bmm2_scale)

For on-device fused scale tensors, which can change dynamically, provide the scale factors as torch.Tensor.
  • (bmm1_scale_log2_tensor, bmm2_scale_tensor)

  • Currently, only FP8 tensor core operations support this mode.

When both are provided, the dynamic scale factor tensors will be used.
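
As a worked instance of the fused-scale formulas in this note, the following sketch computes the float form of both scales. The helper name is hypothetical, and the unit quantization scales and sm_scale = 1.0 are placeholder assumptions; real values depend on the model and its quantization scheme.

```python
import math


def fused_mla_scales(head_dim_qk, q_scale=1.0, k_scale=1.0,
                     v_scale=1.0, o_scale=1.0, sm_scale=1.0):
    """Return (bmm1_scale, bmm2_scale) as floats, following the note's formulas.

    All scale defaults are illustrative assumptions, not library defaults.
    """
    bmm1 = q_scale * k_scale * sm_scale / math.sqrt(head_dim_qk)
    bmm2 = v_scale * o_scale
    return bmm1, bmm2


# head_dim_qk = kv_lora_rank + qk_rope_head_dim = 512 + 64 = 576
bmm1_scale, bmm2_scale = fused_mla_scales(512 + 64)
```

Because sqrt(576) = 24, the unit-scale case gives bmm1_scale = 1/24. Passing plain floats like these matches the static-constant mode described above.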

On SM100/SM103 dense MLA, calling under flashinfer.autotune(True) with backend="auto" profiles both trtllm-gen and cute-dsl across a bucketed batch sweep up to each runner’s kernel/workspace cap and caches the winning runner per shape signature. Subsequent calls under autotune(False) dispatch to the cached choice; any batch outside the tuned range falls back to a default runner with a one-time warning.

The autotune bucket range and cache key do not depend on kv_cache.shape[0] (the number of pages in the pool), so reallocating the pool between tuning and inference does not invalidate cached choices. However, the page-aliasing ratio during profiling does depend on the pool size: synthetic block_tables are filled by uniform random sampling into [0, kv_cache.shape[0]), so a small pool produces high aliasing (L2-resident reads) and a large pool produces low aliasing (HBM-bound reads). For best profile fidelity, autotune with a kv_cache whose size reflects the production page-sharing pattern of your workload (e.g., heavily shared prefix → smaller pool; independent contexts → larger pool).
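
To get a feel for how pool size drives aliasing, the expected number of distinct pages can be estimated in closed form. This is a back-of-the-envelope sketch, not part of FlashInfer: it assumes each synthetic block-table entry is drawn independently and uniformly with replacement, and it defines the aliasing ratio (a term made up for this example) as the fraction of entries that repeat an already-referenced page.

```python
def expected_aliasing_ratio(num_pages: int, num_entries: int) -> float:
    """Expected fraction of block-table entries that duplicate another entry.

    Assumes independent uniform sampling with replacement from
    [0, num_pages); both argument names are hypothetical.
    """
    if num_pages <= 0 or num_entries <= 0:
        raise ValueError("num_pages and num_entries must be positive")
    # Expected distinct pages after num_entries uniform draws from num_pages.
    expected_distinct = num_pages * (1.0 - (1.0 - 1.0 / num_pages) ** num_entries)
    return 1.0 - expected_distinct / num_entries


small_pool = expected_aliasing_ratio(num_pages=256, num_entries=4096)
large_pool = expected_aliasing_ratio(num_pages=1_000_000, num_entries=4096)
```

With 4096 table entries, a 256-page pool makes most entries repeats, while a million-page pool makes repeats rare, which matches the L2-resident versus HBM-bound contrast described above.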