flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla

flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla(query: Tensor, kv_cache: Tensor, workspace_buffer: Tensor, qk_nope_head_dim: int, kv_lora_rank: int, qk_rope_head_dim: int, block_tables: Tensor, seq_lens: Tensor, max_seq_len: int, sparse_mla_top_k: int = 0, out: Tensor | None = None, bmm1_scale: float | Tensor = 1.0, bmm2_scale: float | Tensor = 1.0, sinks: List[Tensor] | None = None, enable_pdl: bool = None, backend: str = 'auto') Tensor
Parameters:
  • query ([batch_size, q_len_per_request, num_heads, head_dim_qk], where head_dim_qk = kv_lora_rank + qk_rope_head_dim (the q_nope part has width kv_lora_rank, not qk_nope_head_dim); should be the concatenation of q_nope and q_rope along the last dimension, as in the layout sketch after this parameter list. q_len_per_request is the MTP query length.)

  • kv_cache ([num_pages, page_size, head_dim_ckv + head_dim_kpe], should be the concatenation of ckv_cache and kpe_cache along the last dimension)

  • workspace_buffer ([num_semaphores, 4], used for multi-block mode. Must be zero-initialized before its first use.)

  • qk_nope_head_dim (head dimension of the non-RoPE part of the query/key, must be 128)

  • kv_lora_rank (rank of the compressed (latent) KV cache, must be 512)

  • qk_rope_head_dim (head dimension of the RoPE part of the query/key, must be 64)

  • sparse_mla_top_k (sparse MLA top k, must be 0 for non-sparse MLA.)

  • block_tables (page table of the kv cache, [batch_size, num_pages])

  • seq_lens (sequence length of the kv cache for each request, [batch_size])

  • max_seq_len (max sequence length for kv_cache)

  • out (output tensor; if not provided, one is allocated internally)

  • bmm1_scale (fused scale for the MLA bmm1 input.) – when using the trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.

  • bmm2_scale (fused scale for the MLA bmm2 input.) – when using the trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.

  • sinks (additional value per head in the denominator of the softmax.)

  • backend (str = "auto") – The implementation backend; one of "auto", "xqa", or "trtllm-gen". Defaults to "auto", in which case the backend is chosen based on the device architecture and kernel availability: for sm_100 and sm_103 (Blackwell architecture), "auto" chooses the trtllm-gen backend; for sm_120 (Blackwell architecture), it chooses the xqa backend.
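A minimal sketch of the expected query and kv_cache layouts described above (all sizes, dtypes, and tensor names below are illustrative assumptions, not requirements of the API):

    import torch

    batch_size, q_len_per_request, num_heads = 2, 1, 128   # illustrative sizes
    num_pages, page_size = 8, 32
    kv_lora_rank, qk_rope_head_dim = 512, 64

    # query: concatenation of q_nope (width kv_lora_rank) and q_rope (width qk_rope_head_dim).
    q_nope = torch.randn(batch_size, q_len_per_request, num_heads, kv_lora_rank,
                         dtype=torch.bfloat16, device="cuda")
    q_rope = torch.randn(batch_size, q_len_per_request, num_heads, qk_rope_head_dim,
                         dtype=torch.bfloat16, device="cuda")
    query = torch.cat([q_nope, q_rope], dim=-1)             # head_dim_qk = 576

    # kv_cache: concatenation of the compressed kv (ckv) cache and the kpe cache.
    ckv_cache = torch.randn(num_pages, page_size, kv_lora_rank,
                            dtype=torch.bfloat16, device="cuda")
    kpe_cache = torch.randn(num_pages, page_size, qk_rope_head_dim,
                            dtype=torch.bfloat16, device="cuda")
    kv_cache = torch.cat([ckv_cache, kpe_cache], dim=-1)    # [num_pages, page_size, 576]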

Note

In MLA, the actual BMM1 and BMM2 scales applied are fused as:

    bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
    bmm2_scale = v_scale * o_scale

or, as on-device tensors:

    bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)])
    bmm2_scale = torch.Tensor([v_scale * o_scale])
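A minimal sketch of computing these fused scales, assuming per-tensor quantization scales q_scale, k_scale, v_scale, o_scale and a softmax scale sm_scale (these names are not arguments of this API):

    import torch

    q_scale = k_scale = v_scale = o_scale = 1.0   # assumed per-tensor quantization scales
    sm_scale = 1.0                                # assumed extra softmax scale multiplier
    head_dim_qk = 512 + 64                        # kv_lora_rank + qk_rope_head_dim

    # Static float scales (usable with CUDA graph capture):
    bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
    bmm2_scale = v_scale * o_scale

    # On-device tensor variant (fp8 tensor core path only):
    bmm1_scale_tensor = torch.tensor([bmm1_scale], dtype=torch.float32, device="cuda")
    bmm2_scale_tensor = torch.tensor([bmm2_scale], dtype=torch.float32, device="cuda")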

The two scale factors should be static constants for CUDA graph capture. Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.

For static constant scale factors, the scale factors should be provided as floats.
  • (bmm1_scale, bmm2_scale)

For on-device fused scale tensors, which may change dynamically, the scale factors should be provided as torch.Tensor.
  • (bmm1_scale_log2_tensor, bmm2_scale_tensor)

  • Currently, only the fp8 tensor core operation supports this mode.

When both are provided, the dynamic scale factor tensors will be used.
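A hedged end-to-end sketch of a call (all sizes, dtypes, the workspace size, and the scale values are illustrative assumptions; the scales follow the fusion formula in the note above):

    import torch
    from flashinfer.mla import trtllm_batch_decode_with_kv_cache_mla

    device = "cuda"
    batch_size, q_len_per_request, num_heads = 2, 1, 128
    qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim = 128, 512, 64
    page_size, pages_per_seq = 32, 4
    num_pages = batch_size * pages_per_seq
    head_dim_qk = kv_lora_rank + qk_rope_head_dim   # 576

    query = torch.randn(batch_size, q_len_per_request, num_heads, head_dim_qk,
                        dtype=torch.bfloat16, device=device)
    kv_cache = torch.randn(num_pages, page_size, kv_lora_rank + qk_rope_head_dim,
                           dtype=torch.bfloat16, device=device)
    block_tables = torch.arange(num_pages, dtype=torch.int32, device=device).view(
        batch_size, pages_per_seq)
    seq_lens = torch.full((batch_size,), page_size * pages_per_seq,
                          dtype=torch.int32, device=device)
    num_semaphores = 1024   # illustrative guess; see the workspace_buffer parameter above
    workspace_buffer = torch.zeros(num_semaphores, 4, dtype=torch.uint8, device=device)

    out = trtllm_batch_decode_with_kv_cache_mla(
        query=query,
        kv_cache=kv_cache,
        workspace_buffer=workspace_buffer,
        qk_nope_head_dim=qk_nope_head_dim,
        kv_lora_rank=kv_lora_rank,
        qk_rope_head_dim=qk_rope_head_dim,
        block_tables=block_tables,
        seq_lens=seq_lens,
        max_seq_len=page_size * pages_per_seq,
        bmm1_scale=1.0 / (head_dim_qk ** 0.5),   # assumes q_scale = k_scale = sm_scale = 1.0
        bmm2_scale=1.0,                          # assumes v_scale = o_scale = 1.0
        backend="auto",
    )
    # `out` is allocated internally here since the `out` argument was not provided.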