flashinfer.mla.xqa_batch_decode_with_kv_cache_mla
- flashinfer.mla.xqa_batch_decode_with_kv_cache_mla(query: Tensor, kv_cache: Tensor, workspace_buffer: Tensor, qk_nope_head_dim: int, kv_lora_rank: int, qk_rope_head_dim: int, block_tables: Tensor, seq_lens: Tensor, max_seq_len: int, out: Tensor | None = None, bmm1_scale: float | Tensor = 1.0, bmm2_scale: float | Tensor = 1.0, sinks: List[Tensor] | None = None, enable_pdl: bool | None = None) → Tensor
XQA-backend batched MLA decode.
Single-query (MTP-aware) MLA decode kernel optimized for SM120a / SM121a tensor cores. Accepts the concatenated (q_nope || q_rope) query and the concatenated (ckv || kpe) paged KV cache layout used by DeepSeek-V3 / R1 inference.
- Parameters:
  - query (torch.Tensor) – Query tensor with shape [batch_size, q_len_per_request, num_heads, head_dim_qk], where head_dim_qk = kv_lora_rank + qk_rope_head_dim. Must be the concatenation [q_nope, q_rope]. q_len_per_request is the MTP query length and is currently required to be 1.
  - kv_cache (torch.Tensor) – Paged KV cache, either 3-D [num_pages, page_size, kv_lora_rank + qk_rope_head_dim] or 4-D [num_pages, 1, page_size, kv_lora_rank + qk_rope_head_dim]. The last dimension is the concatenation [ckv_cache, kpe_cache]. Both shapes are accepted for backward compatibility.
  - workspace_buffer (torch.Tensor) – Pre-allocated workspace buffer. Must be zero-initialized on first use.
  - qk_nope_head_dim (int) – Non-RoPE head dimension. Must be 128. Will be removed in 1.0; use kv_lora_rank instead.
  - kv_lora_rank (int) – Rank of the latent KV projection. Must be 512.
  - qk_rope_head_dim (int) – RoPE head dimension appended to the latent projection. Must be 64.
  - block_tables (torch.Tensor) – Per-request paged KV block table, shape [batch_size, num_pages].
  - seq_lens (torch.Tensor) – Per-request KV sequence lengths, shape [batch_size].
  - max_seq_len (int) – Maximum KV sequence length used for kernel scheduling. Will be removed in 1.0; the kernel reads the per-request lengths from seq_lens.
  - out (Optional[torch.Tensor]) – Optional output tensor of shape [batch_size, num_heads, kv_lora_rank] and dtype torch.bfloat16. If None, it is allocated internally.
  - bmm1_scale (Union[float, torch.Tensor]) – Fused scale for MLA BMM1 (see Note). Pass a float for static (CUDA-graph safe) scales, or a torch.Tensor for on-device dynamic scales (FP8 only).
  - bmm2_scale (Union[float, torch.Tensor]) – Fused scale for MLA BMM2 (see Note). Same typing rules as bmm1_scale.
  - sinks (Optional[List[torch.Tensor]]) – Attention-sink tensors. Currently unsupported; must be None.
  - enable_pdl (Optional[bool]) – Programmatic Dependent Launch (PDL) toggle. If None, PDL support is auto-detected from the device.
- Returns:
  Attention output of shape [batch_size, num_heads, kv_lora_rank] and dtype torch.bfloat16.
- Return type:
  torch.Tensor
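To show how the documented shape constraints fit together, here is a minimal sketch of a pre-flight check. check_xqa_mla_shapes is a hypothetical helper, not part of flashinfer. It works on plain shape tuples so it runs without a GPU, and it returns the expected output shape. It also assumes that a request of n tokens occupies ceil(n / page_size) entries of its block_tables row. That is the usual paged-KV convention, but this page does not state it.

```python
from typing import List, Optional, Sequence, Tuple


def check_xqa_mla_shapes(
    query_shape: Sequence[int],
    kv_cache_shape: Sequence[int],
    block_tables_shape: Sequence[int],
    seq_lens: Sequence[int],
    qk_nope_head_dim: int = 128,
    kv_lora_rank: int = 512,
    qk_rope_head_dim: int = 64,
    sinks: Optional[List[object]] = None,
) -> Tuple[int, int, int]:
    """Hypothetical pre-flight check (not a flashinfer API).

    Returns the expected output shape [batch_size, num_heads, kv_lora_rank].
    """
    # Fixed head dimensions documented for this kernel.
    if (qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim) != (128, 512, 64):
        raise ValueError(
            "expected qk_nope_head_dim=128, kv_lora_rank=512, qk_rope_head_dim=64"
        )
    head_dim_qk = kv_lora_rank + qk_rope_head_dim  # 576

    # query: [batch_size, q_len_per_request, num_heads, head_dim_qk]
    if len(query_shape) != 4:
        raise ValueError("query must be 4-D")
    batch_size, q_len, num_heads, q_dim = query_shape
    if q_len != 1:
        raise ValueError("q_len_per_request must currently be 1")
    if q_dim != head_dim_qk:
        raise ValueError(f"query head_dim_qk must be {head_dim_qk}")

    # kv_cache: normalize 4-D [num_pages, 1, page_size, D] to 3-D [num_pages, page_size, D].
    kv = tuple(kv_cache_shape)
    if len(kv) == 4:
        if kv[1] != 1:
            raise ValueError("4-D kv_cache must have size 1 in dimension 1")
        kv = (kv[0], kv[2], kv[3])
    if len(kv) != 3 or kv[2] != head_dim_qk:
        raise ValueError(f"kv_cache last dimension must be {head_dim_qk}")
    page_size = kv[1]

    # block_tables: [batch_size, num_pages]; seq_lens: [batch_size]
    if len(block_tables_shape) != 2 or block_tables_shape[0] != batch_size:
        raise ValueError("block_tables must have shape [batch_size, num_pages]")
    if len(seq_lens) != batch_size:
        raise ValueError("seq_lens must have batch_size entries")

    # Assumption: a request of n tokens needs ceil(n / page_size) block-table entries.
    pages_needed = max(((n + page_size - 1) // page_size for n in seq_lens), default=0)
    if pages_needed > block_tables_shape[1]:
        raise ValueError("block_tables has too few columns for the longest request")

    if sinks is not None:
        raise ValueError("sinks are currently unsupported; pass None")
    return (batch_size, num_heads, kv_lora_rank)


# 4-D cache with page_size 32: the longest request (250 tokens) needs 8 pages.
print(check_xqa_mla_shapes((2, 1, 128, 576), (64, 1, 32, 576), (2, 8), [100, 250]))
# (2, 128, 512)
```

Checking plain tuples rather than tensors keeps the sketch independent of device and dtype. Against real inputs you would pass tuple(t.shape) and seq_lens.tolist().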
Note
In MLA, the BMM1 and BMM2 scales are fused as:
bmm1_scale = q_scale * k_scale * sm_scale / sqrt(head_dim_qk)
bmm2_scale = v_scale * o_scale
The scale factors must be static constants for CUDA graph capture. Either the float pair (bmm1_scale, bmm2_scale) or the on-device tensor pair (bmm1_scale_log2_tensor, bmm2_scale_tensor) may be passed. When tensor inputs are supplied, the on-device path is taken (FP8 only).
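The fused-scale arithmetic in the Note can be written out as a small helper. fuse_mla_scales is hypothetical, not a flashinfer function. It mirrors the two formulas above term by term and returns plain Python floats, which corresponds to the static-scale path. The right value of sm_scale for a given model is not specified on this page, so it defaults to 1.0 here only for illustration.

```python
import math
from typing import Tuple


def fuse_mla_scales(
    head_dim_qk: int,
    sm_scale: float = 1.0,
    q_scale: float = 1.0,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
    o_scale: float = 1.0,
) -> Tuple[float, float]:
    """Hypothetical helper: (bmm1_scale, bmm2_scale) per the fused MLA formulas."""
    # bmm1_scale = q_scale * k_scale * sm_scale / sqrt(head_dim_qk)
    bmm1_scale = q_scale * k_scale * sm_scale / math.sqrt(head_dim_qk)
    # bmm2_scale = v_scale * o_scale
    bmm2_scale = v_scale * o_scale
    return bmm1_scale, bmm2_scale


# head_dim_qk = kv_lora_rank + qk_rope_head_dim = 512 + 64 = 576 = 24**2,
# so with every other factor at 1.0, bmm1 equals 1/24.
bmm1, bmm2 = fuse_mla_scales(head_dim_qk=512 + 64)
print(bmm1, bmm2)
```

Returning floats keeps the values static, as the Note requires for CUDA graph capture. The FP8-only tensor path is not covered by this sketch.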