flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla
flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla(query: Tensor, kv_cache: Tensor, workspace_buffer: Tensor, qk_nope_head_dim: int, kv_lora_rank: int, qk_rope_head_dim: int, block_tables: Tensor, seq_lens: Tensor | None, max_seq_len: int, sparse_mla_top_k: int = 0, out: Tensor | None = None, bmm1_scale: float | Tensor = 1.0, bmm2_scale: float | Tensor = 1.0, sinks: List[Tensor] | None = None, skip_softmax_threshold_scale_factor: float | None = None, enable_pdl: bool | None = None, backend: str = 'auto', is_var_seq: bool = True, uses_shared_paged_kv_idx: bool = True, lse: Tensor | None = None, return_lse: bool = False, cute_dsl_impl: str = 'auto', kv_scale_format: str = 'auto', cum_seq_lens_q: Tensor | None = None, max_q_len: int | None = None) -> Tensor | Tuple[Tensor, Tensor]
Decode MLA with TRTLLM-GEN, CuteDSL, XQA, or SM120/SM121 sparse kernels.
With backend="auto", SM100/SM103 devices use TRTLLM-GEN for sparse MLA when sparse_mla_top_k > 0. SM120/SM121 devices use the packed sparse backend for sparse_mla_top_k > 0 and XQA for dense decode.

Parameters:
- query (torch.Tensor) – Query tensor with shape [batch_size, q_len_per_request, num_heads, head_dim_qk], where head_dim_qk = kv_lora_rank + qk_rope_head_dim. For the SM120/SM121 v32/GLM sparse backend, this must be BF16 with head_dim_qk == 576.
- kv_cache (torch.Tensor) – For TRTLLM-GEN, CuteDSL, and XQA, the paged KV cache is [num_pages, page_size, kv_lora_rank + qk_rope_head_dim] or [num_pages, 1, page_size, kv_lora_rank + qk_rope_head_dim] and uses the query-compatible dense dtype. For the SM120/SM121 v32/GLM sparse backend, this is a packed uint8 cache with 656 bytes per token, shaped [num_pages, page_size, 656] or [num_pages, 1, page_size, 656].
- workspace_buffer (torch.Tensor) – Pre-allocated workspace buffer. Must be zero-initialized on first use by kernels that use semaphore state.
- qk_nope_head_dim (int) – Non-RoPE query dimension. Dense MLA paths commonly use 128 or 64 depending on the model. The SM120/SM121 sparse v32/GLM backend ignores this value and validates query.shape[-1] == 576 instead.
- kv_lora_rank (int) – Latent KV rank. TRTLLM-GEN and SM120/SM121 sparse v32/GLM use 512.
- qk_rope_head_dim (int) – RoPE head dimension. Sparse MLA paths use 64.
- block_tables (torch.Tensor) – Page table for dense MLA backends when sparse_mla_top_k == 0. For SM100/SM103 TRTLLM-GEN sparse MLA it is the usual paged block table. When cum_seq_lens_q is provided with sparse MLA, pass compact sparse rows in flattened query-token order with shape [total_q, sparse_mla_top_k]. For SM120/SM121 sparse v32/GLM, it is the sparse index matrix and must have shape [batch_size, q_len_per_request, sparse_mla_top_k] with int32 physical token indices.
- seq_lens (Optional[torch.Tensor]) – Per-request KV sequence lengths for dense and TRTLLM-GEN paths. For SM120/SM121 sparse v32/GLM, pass [batch_size, q_len_per_request] or flattened [batch_size * q_len_per_request] active top-k lengths; if None, every column in block_tables is active.
- max_seq_len (int) – Maximum KV sequence length used for dense/TRTLLM-GEN scheduling. Ignored by the SM120/SM121 sparse v32/GLM backend.
- sparse_mla_top_k (int) – Enables sparse MLA when greater than zero. On SM100/SM103 this selects the TRTLLM-GEN sparse page-table path. On SM120/SM121 with backend="auto" or backend="sparse", this is the width of the packed v32/GLM sparse index matrix. The TRTLLM-GEN backend supports dense query input or flattened query input plus cum_seq_lens_q.
- out (Optional[torch.Tensor]) – Output tensor. If not provided, it is allocated internally.
- bmm1_scale (Union[float, torch.Tensor]) – Fused scale for MLA BMM1. TRTLLM-GEN accepts an FP32 tensor or float. CuteDSL, XQA, and SM120/SM121 sparse v32/GLM require a float.
- bmm2_scale (Union[float, torch.Tensor]) – Fused scale for MLA BMM2. TRTLLM-GEN accepts an FP32 tensor or float. CuteDSL and XQA require a float. SM120/SM121 sparse v32/GLM requires 1.0.
- sinks (Optional[List[torch.Tensor]]) – Additional value per head in the denominator of the softmax. Supported by trtllm-gen, cute-dsl, and sparse. On cute-dsl this requires the modular implementation; cute_dsl_impl="auto" (the default) promotes to modular automatically, and cute_dsl_impl="monolithic" with sinks set raises ValueError.
- skip_softmax_threshold_scale_factor (Optional[float]) – Threshold scale factor for skipping softmax operations. Providing a value enables skip-softmax sparsity (see https://arxiv.org/abs/2512.12087). If no value is provided, standard attention is used. A higher threshold generally increases kernel performance at the cost of accuracy. The actual threshold equals the provided scale factor divided by the context length.
- enable_pdl (Optional[bool]) – Programmatic Dependent Launch toggle. When None (default), auto-detects support from the query device. Honored by the trtllm-gen and xqa backends; ignored by cute-dsl.
- backend (str = "auto") – Implementation backend. Valid values are "auto", "xqa", "trtllm-gen", "cute-dsl", and "sparse". "auto" chooses "trtllm-gen" for SM100/SM103 sparse MLA and chooses "sparse" for SM120/SM121 when sparse_mla_top_k > 0; otherwise SM120/SM121 dense decode uses "xqa". The cute-dsl backend has two interchangeable implementations (monolithic and modular) on the same shape/dtype envelope; which one runs is controlled by the cute_dsl_impl kwarg below.
- is_var_seq (bool) – Whether the sequence length is variable. If True, sequence lengths may differ across requests; otherwise, the sequence length is fixed for all requests in the batch.
- uses_shared_paged_kv_idx (bool = True) – Whether K and V page indices are shared as a unified index. True (default) uses the vLLM/FlashInfer layout with a 2D page table. False uses the TRT-LLM layout with a 3D page table [batch_size, 2, max_num_pages_per_seq]. False is only supported by TRTLLM-GEN.
- lse (Optional[torch.Tensor] = None) – Optional pre-allocated buffer for Log-Sum-Exp values. Supported by the trtllm-gen, cute-dsl, and sparse backends. Must have dtype torch.float32. Accepted shapes: [batch_size * q_len_per_request, num_qo_heads] (TRTLLM-GEN native; accepted by sparse), or [batch_size, q_len_per_request, num_qo_heads] (cute-dsl native). If return_lse is True and this is None, a buffer will be allocated by the backend.
- return_lse (bool = False) – Whether to return LSE values. Supported by the trtllm-gen, cute-dsl, and sparse backends. When True, the function returns (out, lse).
- cute_dsl_impl (str = "auto") – Which cute-dsl implementation to use. Honored when backend="cute-dsl" and when backend="auto" considers the cute-dsl candidate; ignored for non-cute-dsl backends.
  - "auto" (default): picks monolithic by default, automatically promoted to modular when the call uses a feature monolithic doesn’t support (currently sinks).
  - "modular": strict. Always runs the modular kernels.
  - "monolithic": strict. Always runs the monolithic kernels; raises ValueError if the call uses any modular-only feature (e.g. sinks).
- kv_scale_format (str = "auto") – Scale semantics for the SM120/SM121 packed v32/GLM sparse backend. "auto" and "pow2_fp32" select DSv3.2 power-of-2 FP32 inline scales; "arbitrary_fp32" selects GLM-style arbitrary FP32 inline scales. Ignored by the trtllm-gen, xqa, and cute-dsl backends.
- cum_seq_lens_q (Optional[torch.Tensor] = None) – Cumulative query sequence lengths for variable-length query support, shape [batch_size + 1], dtype torch.int32. Must be a 1D tensor with at least two entries. When max_q_len is not provided, this function validates that it starts with 0, ends at query.size(0), and is monotonically non-decreasing. Only supported by the trtllm-gen backend. When provided, query must have shape [total_q, num_heads, head_dim_qk]. For best performance, provide max_q_len together with cum_seq_lens_q to avoid host-side metadata validation.
- max_q_len (Optional[int] = None) – Maximum query sequence length across all requests when using cum_seq_lens_q. Provide with cum_seq_lens_q to avoid host-side metadata validation. Must be greater than or equal to the maximum segment length represented by cum_seq_lens_q. Over-estimation is safe but may waste work; under-estimation is invalid and may produce incorrect output.
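The conditions stated for cum_seq_lens_q and max_q_len can be sketched on the host with plain Python lists. This is only an illustration of the documented constraints, not FlashInfer's own validation code: the helper names build_cum_seq_lens_q, check_cum_seq_lens_q, and max_segment_len are hypothetical, and the per-request query lengths are made up.

```python
from itertools import accumulate


def build_cum_seq_lens_q(q_lens):
    """Prefix-sum per-request query lengths into a list of batch_size + 1 entries."""
    return [0] + list(accumulate(q_lens))


def check_cum_seq_lens_q(cum_seq_lens_q, total_q):
    """Hypothetical host-side check mirroring the documented conditions."""
    if len(cum_seq_lens_q) < 2:
        raise ValueError("cum_seq_lens_q needs at least two entries")
    if cum_seq_lens_q[0] != 0:
        raise ValueError("cum_seq_lens_q must start with 0")
    if cum_seq_lens_q[-1] != total_q:
        raise ValueError("cum_seq_lens_q must end at query.size(0)")
    if any(b < a for a, b in zip(cum_seq_lens_q, cum_seq_lens_q[1:])):
        raise ValueError("cum_seq_lens_q must be monotonically non-decreasing")


def max_segment_len(cum_seq_lens_q):
    """Smallest valid max_q_len: the longest per-request query segment."""
    return max(b - a for a, b in zip(cum_seq_lens_q, cum_seq_lens_q[1:]))


# Illustrative batch of three requests with 3, 1, and 4 query tokens.
q_lens = [3, 1, 4]
cum = build_cum_seq_lens_q(q_lens)
check_cum_seq_lens_q(cum, total_q=sum(q_lens))
max_q_len = max_segment_len(cum)
```

In practice the list would presumably be converted with something like torch.tensor(cum, dtype=torch.int32) on the query device, and max_q_len passed alongside it so the host-side validation can be skipped.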
Note
In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
    bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
    bmm2_scale = v_scale * o_scale
or,
    bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)])
    bmm2_scale = torch.Tensor([v_scale * o_scale])
The two scale factors should be static constants for CUDA graph capture. Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.
- For static constant scale factors, the scale factors should be provided as float.
(bmm1_scale, bmm2_scale)
- For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.
(bmm1_scale_log2_tensor, bmm2_scale_tensor)
Currently, only fp8 tensor core operation supports this mode.
When both are provided, the dynamic scale factor tensors will be used.
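As a worked instance of the fused-scale formulas in this note, the float form can be computed as follows. The helper fused_mla_scales is hypothetical, and the quantization scales of 1.0 are placeholder values chosen only to keep the arithmetic visible; head_dim_qk uses the kv_lora_rank = 512 and qk_rope_head_dim = 64 values mentioned in the parameter descriptions.

```python
import math


def fused_mla_scales(q_scale, k_scale, v_scale, o_scale, sm_scale, head_dim_qk):
    """Fold the per-tensor scales into the two floats described in the note."""
    bmm1_scale = q_scale * k_scale * sm_scale / math.sqrt(head_dim_qk)
    bmm2_scale = v_scale * o_scale
    return bmm1_scale, bmm2_scale


kv_lora_rank, qk_rope_head_dim = 512, 64
head_dim_qk = kv_lora_rank + qk_rope_head_dim  # 576, so sqrt(head_dim_qk) == 24

# Placeholder scales (assumed values, not taken from any real checkpoint).
bmm1_scale, bmm2_scale = fused_mla_scales(
    q_scale=1.0, k_scale=1.0, v_scale=1.0, o_scale=1.0,
    sm_scale=1.0, head_dim_qk=head_dim_qk,
)
```

With every scale set to 1.0, bmm1_scale reduces to 1 / 24 and bmm2_scale to 1.0. Passing these as Python floats matches the static-constant case described above.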
On SM100/SM103 dense MLA, calling under flashinfer.autotune(True) with backend="auto" profiles both trtllm-gen and cute-dsl across a bucketed batch sweep up to each runner’s kernel/workspace cap and caches the winning runner per shape signature. Subsequent calls under autotune(False) dispatch to the cached choice; any batch outside the tuned range falls back to a default runner with a one-time warning.

The autotune bucket range and cache key do not depend on kv_cache.shape[0] (the number of pages in the pool), so reallocating the pool between tuning and inference does not invalidate cached choices. However, the page-aliasing ratio during profiling does depend on the pool size: synthetic block_tables are filled by uniform random sampling into [0, kv_cache.shape[0]), so a small pool produces high aliasing (L2-resident reads) and a large pool produces low aliasing (HBM-bound reads). For best profile fidelity, autotune with a kv_cache whose size reflects the production page-sharing pattern of your workload (e.g., heavily shared prefix → smaller pool; independent contexts → larger pool).
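The effect of pool size on page aliasing can be illustrated with a small stdlib-only simulation. It only mimics the uniform sampling described above and does not call FlashInfer; the function aliasing_ratio, the pool sizes, and the table size are all assumptions made for the sketch.

```python
import random


def aliasing_ratio(num_pool_pages, num_table_entries, seed=0):
    """Fraction of block-table entries that repeat an already-sampled page."""
    rng = random.Random(seed)
    # Uniform sampling into [0, num_pool_pages), as described for synthetic block_tables.
    table = [rng.randrange(num_pool_pages) for _ in range(num_table_entries)]
    return 1.0 - len(set(table)) / num_table_entries


# Same number of block-table entries, two very different pool sizes.
small_pool = aliasing_ratio(num_pool_pages=64, num_table_entries=4096)
large_pool = aliasing_ratio(num_pool_pages=1_000_000, num_table_entries=4096)
```

With 4096 entries drawn from a 64-page pool, at most 64 entries can be distinct, so nearly every read revisits a page. With a million-page pool, repeats are rare. That gap is what separates an L2-resident profile from an HBM-bound one.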