flashinfer.mla

MLA (Multi-head Latent Attention) is an attention mechanism proposed in the DeepSeek series of models (DeepSeek-V2, DeepSeek-V3, and DeepSeek-R1).

PagedAttention for MLA

class flashinfer.mla.BatchMLAPagedAttentionWrapper(float_workspace_buffer: torch.Tensor, use_cuda_graph: bool = False, qo_indptr: torch.Tensor | None = None, kv_indptr: torch.Tensor | None = None, kv_indices: torch.Tensor | None = None, kv_len_arr: torch.Tensor | None = None, backend: str = 'fa2')

Wrapper class for MLA (Multi-head Latent Attention) PagedAttention on DeepSeek models. This kernel can be used for decode and incremental prefill, and should be used together with the Matrix Absorption trick, where \(W_{UQ}\) is absorbed with \(W_{UK}\), and \(W_{UV}\) is absorbed with \(W_{O}\). For MLA attention without Matrix Absorption (head_dim_qk=192 and head_dim_vo=128, as used in the prefill self-attention stage), please use flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper.
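The Matrix Absorption identity can be checked numerically on toy data. The sketch below is an illustration only: it uses small random single-head matrices, and the names W_QK, W_OV and all dimensions are hypothetical rather than FlashInfer internals. It folds \(W_{UK}^\top\) into the query projection and \(W_{UV}\) into \(W_O\), then confirms that the logits and the projected output match the un-absorbed computation. Because the logits are unchanged, the softmax scale still comes from the pre-absorption head dimension (128 + 64 in the example below), while the absorbed query has the compressed-KV width, which is why q_nope has head_dim_ckv columns.

```python
import math
import random

def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def transpose(m):
    return [list(r) for r in zip(*m)]

def dot(u, v):
    return sum(x * y for x, y in zip(u, v))

def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
d_cq, d_c, d_h, d_model = 6, 4, 3, 5  # toy sizes; d_c plays the role of head_dim_ckv

# Hypothetical single-head weights, each stored as (out_dim x in_dim).
W_UQ = rand_matrix(d_h, d_cq, rng)    # query latent -> per-head query
W_UK = rand_matrix(d_h, d_c, rng)     # compressed KV -> per-head key
W_UV = rand_matrix(d_h, d_c, rng)     # compressed KV -> per-head value
W_O = rand_matrix(d_model, d_h, rng)  # per-head value -> model output

c_q = [rng.gauss(0.0, 1.0) for _ in range(d_cq)]                      # one query token
c_kv = [[rng.gauss(0.0, 1.0) for _ in range(d_c)] for _ in range(3)]  # three cached tokens

# Key side, naive: up-project the query and every key, then take dot products.
q = matvec(W_UQ, c_q)
naive_logits = [dot(q, matvec(W_UK, c)) for c in c_kv]
# Key side, absorbed: precompute W_UK^T W_UQ once; cached latents act directly as keys.
W_QK = matmul(transpose(W_UK), W_UQ)  # d_c x d_cq
q_abs = matvec(W_QK, c_q)             # length d_c, like q_nope after absorption
absorbed_logits = [dot(q_abs, c) for c in c_kv]

# Value side for fixed attention weights p, naive: up-project values, mix, apply W_O.
p = [0.2, 0.5, 0.3]
values = [matvec(W_UV, c) for c in c_kv]
mixed_v = [sum(pj * v[i] for pj, v in zip(p, values)) for i in range(d_h)]
naive_out = matvec(W_O, mixed_v)
# Value side, absorbed: mix the latents first, then apply the precomputed W_O W_UV.
W_OV = matmul(W_O, W_UV)              # d_model x d_c
mixed_c = [sum(pj * c[i] for pj, c in zip(p, c_kv)) for i in range(d_c)]
absorbed_out = matvec(W_OV, mixed_c)

# Logits are unchanged by absorption, so the scale uses the pre-absorption
# query/key head dim: 128 (nope) + 64 (rope).
sm_scale = 1.0 / math.sqrt(128 + 64)
```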

The paged KV-cache layout used by MLA is explained in our tutorial MLA Page Layout.
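As a rough illustration of how the paged arguments relate (based on the shapes documented for plan() and run() below), this sketch resolves each cached token of one request to a (page, offset) slot. The helper names are hypothetical, and it assumes the request's pages are listed in token order with only the last page possibly partially filled. Since ckv_cache and kpe_cache both have shape [num_pages, page_size, ...], the same slots index either cache.

```python
def request_token_slots(kv_indptr, kv_indices, kv_len_arr, page_size, req):
    """Hypothetical helper: (page_id, offset_in_page) for every cached token of `req`."""
    pages = kv_indices[kv_indptr[req]:kv_indptr[req + 1]]
    kv_len = kv_len_arr[req]
    # The request should own exactly ceil(kv_len / page_size) pages.
    if len(pages) != (kv_len + page_size - 1) // page_size:
        raise ValueError("page count does not match ceil(kv_len / page_size)")
    return [(pages[t // page_size], t % page_size) for t in range(kv_len)]

def gather(cache, slots):
    """Gather rows from a cache laid out as [num_pages][page_size][head_dim]."""
    return [cache[page][off] for page, off in slots]

page_size = 4
kv_indptr = [0, 2, 3]   # request 0 owns 2 pages, request 1 owns 1 page
kv_indices = [7, 2, 5]  # physical page ids, in token order
kv_len_arr = [6, 3]     # request 0: one full page + 2 tokens; request 1: 3 tokens

# Toy ckv cache of shape [8][page_size][1]; each entry encodes its own slot.
ckv_cache = [[[10 * p + o] for o in range(page_size)] for p in range(8)]

slots0 = request_token_slots(kv_indptr, kv_indices, kv_len_arr, page_size, 0)
tokens0 = gather(ckv_cache, slots0)  # request 0's compressed-KV rows, in order
```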

For more details about the MLA computation, Matrix Absorption and FlashInfer’s MLA implementation, please refer to our blog post.

Example

>>> import torch
>>> import flashinfer
>>> num_local_heads = 128
>>> batch_size = 114
>>> head_dim_ckv = 512
>>> head_dim_kpe = 64
>>> page_size = 1
>>> mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
...     torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0),
...     backend="fa2"
... )
>>> q_indptr = torch.arange(0, batch_size + 1).to(0).int() # for decode, each query length is 1
>>> kv_lens = torch.full((batch_size,), 999, dtype=torch.int32).to(0)
>>> kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * 999
>>> kv_indices = torch.arange(0, batch_size * 999).to(0).int()
>>> q_nope = torch.randn(
...     batch_size * 1, num_local_heads, head_dim_ckv, dtype=torch.bfloat16, device="cuda"
... )
>>> q_pe = torch.zeros(
...     batch_size * 1, num_local_heads, head_dim_kpe, dtype=torch.bfloat16, device="cuda"
... )
>>> ckv = torch.randn(
...     batch_size * 999, 1, head_dim_ckv, dtype=torch.bfloat16, device="cuda"
... )
>>> kpe = torch.zeros(
...     batch_size * 999, 1, head_dim_kpe, dtype=torch.bfloat16, device="cuda"
... )
>>> sm_scale = 1.0 / ((128 + 64) ** 0.5)  # use head dimension before matrix absorption
>>> mla_wrapper.plan(
...     q_indptr,
...     kv_indptr,
...     kv_indices,
...     kv_lens,
...     num_local_heads,
...     head_dim_ckv,
...     head_dim_kpe,
...     page_size,
...     False,  # causal
...     sm_scale,
...     q_nope.dtype,
...     ckv.dtype,
... )
>>> o = mla_wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False)
>>> o.shape
torch.Size([114, 128, 512])

__init__(float_workspace_buffer: torch.Tensor, use_cuda_graph: bool = False, qo_indptr: torch.Tensor | None = None, kv_indptr: torch.Tensor | None = None, kv_indices: torch.Tensor | None = None, kv_len_arr: torch.Tensor | None = None, backend: str = 'fa2') → None

Constructor for BatchMLAPagedAttentionWrapper.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The user-reserved workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB, and the buffer must be on the same device as the input tensors.

  • use_cuda_graph (bool, optional) – Whether to enable CUDA graph capture for the prefill kernels. If enabled, the auxiliary data structures are stored in the provided buffers, and batch_size cannot change during the lifecycle of this wrapper.

  • qo_indptr (Optional[torch.Tensor]) – The user-reserved buffer to store the qo_indptr array, shape: [batch_size + 1]. This argument is only effective when use_cuda_graph is True.

  • kv_indptr (Optional[torch.Tensor]) – The user-reserved buffer to store the kv_indptr array, shape: [batch_size + 1]. This argument is only effective when use_cuda_graph is True.

  • kv_indices (Optional[torch.Tensor]) – The user-reserved buffer to store the kv_indices array. This argument is only effective when use_cuda_graph is True.

  • kv_len_arr (Optional[torch.Tensor]) – The user-reserved buffer to store the kv_len_arr array, shape: [batch_size]. This argument is only effective when use_cuda_graph is True.

  • backend (str) – The implementation backend, default is “fa2”.
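Because batch_size is fixed for the wrapper's lifetime when use_cuda_graph is True, the index buffers can be sized up front. The sketch below only computes shapes from the sizes stated in the parameter list; the helper name is hypothetical, and sizing kv_indices for the worst case of every request holding max_kv_len tokens is an assumption (plan() only requires kv_indices to hold at least kv_indptr[-1] entries).

```python
def cuda_graph_buffer_shapes(batch_size, max_kv_len, page_size):
    """Hypothetical helper: shapes for the index buffers passed to the constructor
    when use_cuda_graph=True. kv_indices is sized for the worst case."""
    max_pages_per_req = (max_kv_len + page_size - 1) // page_size  # ceil division
    return {
        "qo_indptr": (batch_size + 1,),
        "kv_indptr": (batch_size + 1,),
        "kv_indices": (batch_size * max_pages_per_req,),
        "kv_len_arr": (batch_size,),
    }

# Recommended workspace size: 128MB, i.e. this many int8 elements.
WORKSPACE_NUM_BYTES = 128 * 1024 * 1024

# Sizes matching the example above: 114 requests, 999 cached tokens, page_size 1.
shapes = cuda_graph_buffer_shapes(batch_size=114, max_kv_len=999, page_size=1)
```

With a larger page size the kv_indices buffer shrinks proportionally, e.g. page_size=64 needs 16 pages per 999-token request.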

plan(qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, kv_len_arr: torch.Tensor, num_heads: int, head_dim_ckv: int, head_dim_kpe: int, page_size: int, causal: bool, sm_scale: float, q_data_type: torch.dtype, kv_data_type: torch.dtype) → None

Plan the MLA attention computation.

Parameters:
  • qo_indptr (torch.Tensor) – The indptr of the query/output tensor, shape: [batch_size + 1]. For decoding attention, the length of each query is 1, and the content of the tensor should be [0, 1, 2, ..., batch_size].

  • kv_indptr (torch.Tensor) – The indptr of the paged kv-cache, shape: [batch_size + 1].

  • kv_indices (torch.Tensor) – The page indices of the paged kv-cache, shape: [kv_indptr[-1]] or larger.

  • kv_len_arr (torch.Tensor) – The KV length (number of cached tokens) of each request, shape: [batch_size].

  • num_heads (int) – The number of heads in query/output tensor.

  • head_dim_ckv (int) – The head dimension of the compressed KV cache.

  • head_dim_kpe (int) – The head dimension of the rope part of the k-cache.

  • page_size (int) – The page size of the paged kv-cache.

  • causal (bool) – Whether to use causal attention.

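  • sm_scale (float) – The scale factor for the softmax operation. It is computed from the query/key head dimension before matrix absorption, e.g. 1 / sqrt(128 + 64) as in the example above.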

  • q_data_type (torch.dtype) – The data type of the query tensor.

  • kv_data_type (torch.dtype) – The data type of the kv-cache tensor.
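The index arrays passed to plan() can be derived from per-request lengths. This is a minimal sketch with a hypothetical helper that numbers pages contiguously from 0 and gives each request ceil(kv_len / page_size) pages; a real serving system would take kv_indices from its own page allocator.

```python
from itertools import accumulate

def build_plan_indices(q_lens, kv_lens, page_size):
    """Hypothetical helper: return (qo_indptr, kv_indptr, kv_indices, kv_len_arr)
    as Python lists for a batch with the given query and KV lengths."""
    pages_per_req = [(n + page_size - 1) // page_size for n in kv_lens]  # ceil division
    qo_indptr = [0, *accumulate(q_lens)]        # prefix sums of query lengths
    kv_indptr = [0, *accumulate(pages_per_req)] # prefix sums of page counts
    kv_indices = list(range(kv_indptr[-1]))     # contiguous page ids (assumption)
    kv_len_arr = list(kv_lens)
    return qo_indptr, kv_indptr, kv_indices, kv_len_arr

# Decode, as in the example above but with 3 requests:
# one query token each, 999 cached tokens, page_size = 1.
decode = build_plan_indices([1, 1, 1], [999, 999, 999], page_size=1)

# Incremental prefill: request 0 has 1 query token and 10 cached tokens,
# request 1 has 4 query tokens and 7 cached tokens, page_size = 4.
prefill = build_plan_indices([1, 4], [10, 7], page_size=4)
```

The decode case reproduces the example's pattern (qo_indptr is [0, 1, ..., batch_size] and kv_indptr is a multiple of 999). With torch available, each list could be converted with torch.tensor(lst, dtype=torch.int32, device="cuda"), matching the int32 tensors the example builds with .int().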

run(q_nope: torch.Tensor, q_pe: torch.Tensor, ckv_cache: torch.Tensor, kpe_cache: torch.Tensor, return_lse: Literal[False] = False) → torch.Tensor
run(q_nope: torch.Tensor, q_pe: torch.Tensor, ckv_cache: torch.Tensor, kpe_cache: torch.Tensor, return_lse: Literal[True] = True) → Tuple[torch.Tensor, torch.Tensor]

Run the MLA attention computation.

Parameters:
  • q_nope (torch.Tensor) – The query tensor without rope, shape: [batch_size, num_heads, head_dim_ckv].

  • q_pe (torch.Tensor) – The rope part of the query tensor, shape: [batch_size, num_heads, head_dim_kpe].

  • ckv_cache (torch.Tensor) – The compressed kv-cache tensor (without rope), shape: [num_pages, page_size, head_dim_ckv]. head_dim_ckv is 512 in DeepSeek v2/v3 models.

  • kpe_cache (torch.Tensor) – The rope part of the kv-cache tensor, shape: [num_pages, page_size, head_dim_kpe]. head_dim_kpe is 64 in DeepSeek v2/v3 models.

  • out (Optional[torch.Tensor]) – The output tensor. If not provided, it will be allocated internally.

  • lse (Optional[torch.Tensor]) – The log-sum-exp of the attention logits. If not provided, it will be allocated internally.

  • return_lse (bool, optional) – Whether to return the log-sum-exp value, default is False.
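For checking results on small inputs, the computation run() performs for one query token and one head can be written out directly. This is a plain-Python reference sketch, not the kernel: the function name is hypothetical, no causal mask is applied (as in the decode example, which passes causal=False), and the LSE here is the natural log of the sum of exponentials of the scaled logits; whether the library's lse uses the same base and scaling is not stated in this section. The logit combines both parts of the query, and the compressed KV serves as the value, so the output has head_dim_ckv entries, matching the [114, 128, 512] output in the example. Because the caches carry no head dimension, the same rows are shared by all num_heads query heads.

```python
import math

def mla_attention_reference(q_nope, q_pe, ckv, kpe, sm_scale):
    """Hypothetical reference: MLA attention for ONE query token and ONE head.

    q_nope: [head_dim_ckv], q_pe: [head_dim_kpe]
    ckv:    [kv_len][head_dim_ckv]  compressed KV (used as both key and value)
    kpe:    [kv_len][head_dim_kpe]  rope part of the key
    Returns (output of length head_dim_ckv, natural-log LSE of the scaled logits).
    """
    logits = [
        sm_scale * (sum(a * b for a, b in zip(q_nope, c)) + sum(a * b for a, b in zip(q_pe, k)))
        for c, k in zip(ckv, kpe)
    ]
    m = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    denom = sum(exps)
    lse = m + math.log(denom)
    probs = [e / denom for e in exps]
    out = [sum(p * c[d] for p, c in zip(probs, ckv)) for d in range(len(q_nope))]
    return out, lse

# Two cached tokens; the first matches the query, the rope parts are zero.
out, lse = mla_attention_reference(
    q_nope=[1.0, 0.0], q_pe=[0.0],
    ckv=[[1.0, 0.0], [0.0, 1.0]], kpe=[[0.0], [0.0]],
    sm_scale=1.0,
)
```

Here the logits are [1, 0], so the weights are e/(e+1) and 1/(e+1), and the LSE is log(e + 1).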