flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla
- flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla(query: Tensor, kv_cache: Tensor, workspace_buffer: Tensor, qk_nope_head_dim: int, kv_lora_rank: int, qk_rope_head_dim: int, block_tables: Tensor, seq_lens: Tensor, max_seq_len: int, sparse_mla_top_k: int = 0, out: Tensor | None = None, bmm1_scale: float | Tensor = 1.0, bmm2_scale: float | Tensor = 1.0, sinks: List[Tensor] | None = None, enable_pdl: bool = None, backend: str = 'auto') → Tensor
- Parameters:
query ([batch_size, q_len_per_request, num_heads, head_dim_qk]) – the concatenation of q_nope and q_rope along the last dimension, where the q_nope part has kv_lora_rank dimensions and the q_rope part has qk_rope_head_dim dimensions. q_len_per_request is the MTP (multi-token prediction) query length.
kv_cache ([num_pages, page_size, head_dim_ckv + head_dim_kpe]) – the paged KV cache, formed by concatenating ckv_cache and kpe_cache along the last dimension.
workspace_buffer ([num_semaphores, 4]) – used for multi-block mode. Must be zero-initialized before its first use.
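qk_nope_head_dim (int) – head dimension of the non-RoPE part of the query/key; must be 128.
kv_lora_rank (int) – rank of the compressed (latent) KV representation; must be 512.
qk_rope_head_dim (int) – head dimension of the RoPE part of the query/key; must be 64.
sparse_mla_top_k (int) – top-k value for sparse MLA; must be 0 for non-sparse MLA.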
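block_tables ([batch_size, num_pages]) – page table mapping each request to its pages in kv_cache.
seq_lens ([batch_size]) – KV sequence length of each request.
max_seq_len (int) – maximum KV sequence length in the batch.
out (Tensor | None) – output tensor; if not provided, it is allocated internally.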
bmm1_scale (float | Tensor) – fused scale for the MLA BMM1 input. When using the trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
bmm2_scale (float | Tensor) – fused scale for the MLA BMM2 input. When using the trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
sinks (List[Tensor] | None) – additional per-head value added to the denominator of the softmax.
backend (str = "auto") – the implementation backend: auto, xqa, or trtllm-gen. Defaults to auto. When set to auto, the backend is chosen based on the device architecture and kernel availability. For sm_100 and sm_103 (Blackwell architecture), auto chooses the trtllm-gen backend. For sm_120 (Blackwell architecture), auto chooses the xqa backend.
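The shape relations and the auto backend rule above can be checked before launching the kernel. The following is a minimal sketch, assuming shapes are given as plain tuples and lists so that it runs without a GPU or torch. build_block_tables, check_shapes and resolve_backend are hypothetical helpers, not part of flashinfer. The contiguous page assignment and zero padding of unused block-table slots are illustrative assumptions, not documented requirements.

```python
# Hypothetical pre-flight helpers; shapes are tuples/lists, no torch needed.
import math

KV_LORA_RANK = 512      # kv_lora_rank, must be 512
QK_ROPE_HEAD_DIM = 64   # qk_rope_head_dim, must be 64


def build_block_tables(seq_lens, page_size):
    """Give each request ceil(seq_len / page_size) consecutive physical pages.

    Assumption: unused slots at the end of a row are padded with page 0.
    Returns the table and the total number of physical pages used.
    """
    pages_per_req = [math.ceil(n / page_size) for n in seq_lens]
    width = max(pages_per_req)
    tables, next_page = [], 0
    for count in pages_per_req:
        row = list(range(next_page, next_page + count))
        next_page += count
        tables.append(row + [0] * (width - count))
    return tables, next_page


def check_shapes(query_shape, kv_cache_shape, block_tables, seq_lens):
    """Assert the documented shape relations; return a value for max_seq_len."""
    batch_size, _q_len, _num_heads, head_dim_qk = query_shape
    num_pages, page_size, kv_dim = kv_cache_shape
    # query = q_nope (kv_lora_rank dims) concatenated with q_rope.
    assert head_dim_qk == KV_LORA_RANK + QK_ROPE_HEAD_DIM
    # kv_cache rows = ckv (kv_lora_rank dims) concatenated with kpe.
    assert kv_dim == KV_LORA_RANK + QK_ROPE_HEAD_DIM
    assert len(block_tables) == len(seq_lens) == batch_size
    for row, n in zip(block_tables, seq_lens):
        needed = math.ceil(n / page_size)
        used = row[:needed]
        assert len(used) == needed  # the row covers the whole sequence
        assert all(0 <= p < num_pages for p in used)  # valid page ids
    return max(seq_lens)


def resolve_backend(backend, capability):
    """Apply the documented 'auto' rule to a (major, minor) compute capability."""
    if backend in ("xqa", "trtllm-gen"):
        return backend
    if backend != "auto":
        raise ValueError(f"unknown backend: {backend!r}")
    if capability in ((10, 0), (10, 3)):  # sm_100 / sm_103
        return "trtllm-gen"
    if capability == (12, 0):  # sm_120
        return "xqa"
    raise ValueError(f"'auto' is not documented for sm_{capability[0]}{capability[1]}")


# Usage: two requests with 40 and 100 cached tokens, 32 tokens per page.
tables, total_pages = build_block_tables([40, 100], page_size=32)
max_seq_len = check_shapes((2, 1, 128, 576), (total_pages, 32, 576), tables, [40, 100])
backend = resolve_backend("auto", (10, 0))
```

On a real device, the (major, minor) tuple passed to resolve_backend can come from torch.cuda.get_device_capability().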
Note
In MLA, the BMM1 and BMM2 scales actually applied are fused as:
    bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
    bmm2_scale = v_scale * o_scale
or, as tensors:
    bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)])
    bmm2_scale = torch.Tensor([v_scale * o_scale])
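The fusion formulas above amount to plain arithmetic on the quantization scales. Below is a small sketch with a hypothetical helper, fuse_mla_scales, that returns the two float values. head_dim_qk is left as an argument for the caller to supply, and the value 64 used below is illustrative only. The tensor form would wrap the same two numbers in one-element float32 tensors. That step is omitted here so the sketch has no torch dependency.

```python
# Hypothetical helper: fuse per-tensor scales into the two kernel scalars,
# following bmm1 = q * k * sm / sqrt(head_dim_qk) and bmm2 = v * o.
def fuse_mla_scales(q_scale, k_scale, v_scale, o_scale, sm_scale, head_dim_qk):
    bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
    bmm2_scale = v_scale * o_scale
    return bmm1_scale, bmm2_scale


# Example with unit quantization scales; head_dim_qk=64 is illustrative only.
bmm1, bmm2 = fuse_mla_scales(1.0, 1.0, 1.0, 1.0, sm_scale=1.0, head_dim_qk=64)
print(bmm1, bmm2)  # 0.125 1.0
```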
For CUDA graph capture, the two scale factors should be static constants. Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.
- For static constant scale factors, provide them as floats: (bmm1_scale, bmm2_scale).
- For on-device fused scale tensors, which may change dynamically, provide them as torch.Tensor: (bmm1_scale_log2_tensor, bmm2_scale_tensor). Currently, only fp8 tensor core operation supports this mode.
When both are provided, the dynamic scale factor tensors will be used.
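The float versus tensor rules in this note can be summarized as a small decision function. The following is a hypothetical sketch, scale_mode, that is not part of flashinfer. It assumes that any value that is not a Python int or float stands in for an on-device torch.Tensor, and that backend has already been resolved from auto. Mixing a float with a tensor is rejected because the text above does not cover that case.

```python
# Hypothetical sketch of the scale-mode rules; non-numbers stand in for tensors.
def scale_mode(bmm1_scale, bmm2_scale, backend, fp8_tensor_core):
    is_float = [isinstance(s, (int, float)) for s in (bmm1_scale, bmm2_scale)]
    if all(is_float):
        return "static"  # plain floats, constant once captured
    if any(is_float):
        raise ValueError("mixing a float and a tensor is not covered by this sketch")
    # Tensor scales: documented for trtllm-gen and for fp8 tensor core operation.
    if backend != "trtllm-gen":
        raise ValueError("tensor scales are documented only for trtllm-gen")
    if not fp8_tensor_core:
        raise ValueError("dynamic scales currently require fp8 tensor core operation")
    return "dynamic"


print(scale_mode(0.125, 1.0, backend="xqa", fp8_tensor_core=False))
```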