flashinfer.mla.xqa_batch_decode_with_kv_cache_mla

flashinfer.mla.xqa_batch_decode_with_kv_cache_mla(query: Tensor, kv_cache: Tensor, workspace_buffer: Tensor, qk_nope_head_dim: int, kv_lora_rank: int, qk_rope_head_dim: int, block_tables: Tensor, seq_lens: Tensor, max_seq_len: int, out: Tensor | None = None, bmm1_scale: float | Tensor = 1.0, bmm2_scale: float | Tensor = 1.0, sinks: List[Tensor] | None = None, enable_pdl: bool | None = None) → Tensor

XQA-backend batched MLA decode.

Single-query (MTP-aware) MLA decode kernel optimized for SM120a / SM121a tensor cores. Accepts the concatenated (q_nope || q_rope) query and (ckv || kpe) paged KV cache layout used by DeepSeek-V3 / R1 inference.

Parameters:
  • query (torch.Tensor) – Query tensor with shape [batch_size, q_len_per_request, num_heads, head_dim_qk] where head_dim_qk = kv_lora_rank + qk_rope_head_dim. Must be the concatenation [q_nope, q_rope]. q_len_per_request is the MTP query length and is currently required to be 1.

  • kv_cache (torch.Tensor) – Paged KV cache, either 3-D [num_pages, page_size, kv_lora_rank + qk_rope_head_dim] or 4-D [num_pages, 1, page_size, kv_lora_rank + qk_rope_head_dim]. The last dimension is the concatenation [ckv_cache, kpe_cache]. Both shapes are accepted for backward compatibility.

  • workspace_buffer (torch.Tensor) – Pre-allocated workspace buffer. Must be zero-initialized on first use.

  • qk_nope_head_dim (int) – Non-RoPE head dimension. Must be 128. Will be removed in 1.0; pass kv_lora_rank instead going forward.

  • kv_lora_rank (int) – Rank of the latent KV projection. Must be 512.

  • qk_rope_head_dim (int) – RoPE head dimension appended to the latent projection. Must be 64.

  • block_tables (torch.Tensor) – Per-request paged KV block table, shape [batch_size, num_pages].

  • seq_lens (torch.Tensor) – Per-request KV sequence length, shape [batch_size].

  • max_seq_len (int) – Maximum KV sequence length used for kernel scheduling. Will be removed in 1.0; the kernel reads the per-request lengths from seq_lens.

  • out (Optional[torch.Tensor]) – Optional output tensor of shape [batch_size, num_heads, kv_lora_rank] and dtype torch.bfloat16. If None, it is allocated internally.

  • bmm1_scale (Union[float, torch.Tensor]) – Fused scale for MLA BMM1 (see Note). float for static (CUDA-graph safe) scales; torch.Tensor for on-device dynamic scales (FP8 only).

  • bmm2_scale (Union[float, torch.Tensor]) – Fused scale for MLA BMM2 (see Note). Same typing rules as bmm1_scale.

  • sinks (Optional[List[torch.Tensor]]) – Attention-sink tensors. Currently unsupported and must be None.

  • enable_pdl (Optional[bool]) – Programmatic Dependent Launch toggle. When None, auto-detects support from the device.
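
As a rough illustration of how block_tables and seq_lens relate, the sketch below builds a padded block table in plain Python. The helper name build_block_table is hypothetical and not part of FlashInfer. It also assumes that pages are assigned contiguously per request, that unused slots are padded with 0, and that the second dimension equals the largest number of pages any single request needs.

```python
def build_block_table(seq_lens, page_size):
    """Hypothetical helper: assign contiguous page indices to each request."""
    # Number of pages each request needs (integer ceiling division).
    pages_per_req = [(n + page_size - 1) // page_size for n in seq_lens]
    width = max(pages_per_req)

    table, next_page = [], 0
    for n_pages in pages_per_req:
        row = list(range(next_page, next_page + n_pages))
        row += [0] * (width - n_pages)  # padding value 0 is an assumption
        table.append(row)
        next_page += n_pages
    return table, next_page


seq_lens = [70, 130, 5]
block_table, num_pages_used = build_block_table(seq_lens, page_size=64)
```

In real use these lists would be converted to integer tensors on the GPU before being passed as block_tables and seq_lens; the dtype the kernel expects is not stated in this reference.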

Returns:

Attention output, shape [batch_size, num_heads, kv_lora_rank], dtype torch.bfloat16.

Return type:

torch.Tensor
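
The shape and value constraints described above can be gathered into a small pre-flight check. This is a minimal sketch in plain Python: expected_mla_shapes is a hypothetical helper, not a FlashInfer function, and it only restates the shapes given in the parameter descriptions.

```python
def expected_mla_shapes(batch_size, num_heads, num_pages, page_size,
                        kv_lora_rank=512, qk_rope_head_dim=64,
                        q_len_per_request=1):
    """Hypothetical helper: shapes implied by the parameter descriptions."""
    if kv_lora_rank != 512 or qk_rope_head_dim != 64:
        raise ValueError("kv_lora_rank must be 512 and qk_rope_head_dim must be 64")
    if q_len_per_request != 1:
        raise ValueError("q_len_per_request is currently required to be 1")

    # The query and the KV cache share the same concatenated last dimension.
    head_dim_qk = kv_lora_rank + qk_rope_head_dim
    return {
        "query": (batch_size, q_len_per_request, num_heads, head_dim_qk),
        "kv_cache_3d": (num_pages, page_size, head_dim_qk),
        "kv_cache_4d": (num_pages, 1, page_size, head_dim_qk),
        "seq_lens": (batch_size,),
        "out": (batch_size, num_heads, kv_lora_rank),
    }


shapes = expected_mla_shapes(batch_size=4, num_heads=128,
                             num_pages=32, page_size=64)
```

Comparing tuple(tensor.shape) against these entries before the call can surface layout mistakes earlier than a kernel-side failure would.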

Note

In MLA, the BMM1 and BMM2 scales are fused as:

bmm1_scale = q_scale * k_scale * sm_scale / sqrt(head_dim_qk)
bmm2_scale = v_scale * o_scale

For CUDA graph capture, the scale factors must be static constants. Pass either the float pair (bmm1_scale, bmm2_scale) or the on-device tensor pair (bmm1_scale_log2_tensor, bmm2_scale_tensor). When tensors are supplied, the on-device path is taken (FP8 only).
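
The static float scales from this Note can be computed on the host before the call. The sketch below is illustrative only: fused_mla_scales is a hypothetical helper, and the unit quantization scales and head_dim_qk = 512 + 64 are example values taken from the parameter constraints, not recommended settings.

```python
import math


def fused_mla_scales(q_scale, k_scale, v_scale, o_scale, sm_scale, head_dim_qk):
    """Hypothetical helper: apply the two fused-scale formulas from the Note."""
    bmm1_scale = q_scale * k_scale * sm_scale / math.sqrt(head_dim_qk)
    bmm2_scale = v_scale * o_scale
    return bmm1_scale, bmm2_scale


# Example values (assumptions): unit scales, head_dim_qk = kv_lora_rank + qk_rope_head_dim.
bmm1_scale, bmm2_scale = fused_mla_scales(
    q_scale=1.0, k_scale=1.0, v_scale=1.0, o_scale=1.0,
    sm_scale=1.0, head_dim_qk=512 + 64,
)
```

Because both results are plain Python floats, they correspond to the static path that the bmm1_scale and bmm2_scale parameter descriptions call CUDA-graph safe.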