flashinfer.mla¶
MLA (Multi-head Latent Attention) is an attention mechanism proposed in the DeepSeek series of models (DeepSeek-V2, DeepSeek-V3, and DeepSeek-R1).
PageAttention for MLA¶
- class flashinfer.mla.BatchMLAPagedAttentionWrapper(float_workspace_buffer: torch.Tensor, use_cuda_graph: bool = False, qo_indptr: torch.Tensor | None = None, kv_indptr: torch.Tensor | None = None, kv_indices: torch.Tensor | None = None, kv_len_arr: torch.Tensor | None = None, backend: str = 'fa2')¶
Wrapper class for MLA (Multi-head Latent Attention) PagedAttention on DeepSeek models. This kernel can be used in decode and incremental prefill, and should be used together with the Matrix Absorption trick, where \(W_{UQ}\) is absorbed with \(W_{UK}\) and \(W_{UV}\) is absorbed with \(W_{O}\). For MLA attention without Matrix Absorption (head_dim_qk=192 and head_dim_vo=128, which is used in the prefill self-attention stage), please use flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper.
The Paged KV-Cache layout used in MLA is explained in our tutorial MLA Page Layout.
For more details about the MLA computation, Matrix Absorption, and FlashInfer's MLA implementation, please refer to our blog post.
Example
>>> import torch
>>> import flashinfer
>>> num_local_heads = 128
>>> batch_size = 114
>>> head_dim_ckv = 512
>>> head_dim_kpe = 64
>>> page_size = 1
>>> mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
...     torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0),
...     backend="fa2"
... )
>>> q_indptr = torch.arange(0, batch_size + 1).to(0).int()  # for decode, each query length is 1
>>> kv_lens = torch.full((batch_size,), 999, dtype=torch.int32).to(0)
>>> kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * 999
>>> kv_indices = torch.arange(0, batch_size * 999).to(0).int()
>>> q_nope = torch.randn(
...     batch_size * 1, num_local_heads, head_dim_ckv, dtype=torch.bfloat16, device="cuda"
... )
>>> q_pe = torch.zeros(
...     batch_size * 1, num_local_heads, head_dim_kpe, dtype=torch.bfloat16, device="cuda"
... )
>>> ckv = torch.randn(
...     batch_size * 999, 1, head_dim_ckv, dtype=torch.bfloat16, device="cuda"
... )
>>> kpe = torch.zeros(
...     batch_size * 999, 1, head_dim_kpe, dtype=torch.bfloat16, device="cuda"
... )
>>> sm_scale = 1.0 / ((128 + 64) ** 0.5)  # use head dimension before matrix absorption
>>> mla_wrapper.plan(
...     q_indptr,
...     kv_indptr,
...     kv_indices,
...     kv_lens,
...     num_local_heads,
...     head_dim_ckv,
...     head_dim_kpe,
...     page_size,
...     False,  # causal
...     sm_scale,
...     q_nope.dtype,
...     ckv.dtype,
... )
>>> o = mla_wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False)
>>> o.shape
torch.Size([114, 128, 512])
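The q_nope tensor in the example above already lives in the 512-dimensional latent (compressed-kv) space because \(W_{UK}\) has been absorbed into the query. The sketch below illustrates that absorption step; the weight tensor w_uk and its random initialization are illustrative assumptions, not part of the FlashInfer API.

import torch

# Hypothetical per-head MLA up-projection W_UK (latent 512 -> per-head 128).
# In a real model this weight comes from the checkpoint.
num_heads, head_dim_ckv, head_dim_qk_nope = 128, 512, 128
batch_size = 114

w_uk = torch.randn(num_heads, head_dim_ckv, head_dim_qk_nope,
                   dtype=torch.bfloat16, device="cuda")

# Query after W_UQ, still in the per-head 128-dim space (before absorption).
q = torch.randn(batch_size, num_heads, head_dim_qk_nope,
                dtype=torch.bfloat16, device="cuda")

# Absorb W_UK into the query: q_nope now lives in the 512-dim latent space, so
# q_nope . ckv == (q W_UK^T) . ckv == q . (ckv W_UK) == q . k_nope.
q_nope = torch.einsum("bhd,hcd->bhc", q, w_uk)   # [114, 128, 512]

# Likewise, the wrapper's output stays in the 512-dim latent space and is mapped
# back through W_UV (absorbed with W_O) outside of the attention kernel.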
- __init__(float_workspace_buffer: torch.Tensor, use_cuda_graph: bool = False, qo_indptr: torch.Tensor | None = None, kv_indptr: torch.Tensor | None = None, kv_indices: torch.Tensor | None = None, kv_len_arr: torch.Tensor | None = None, backend: str = 'fa2') None ¶
Constructor for BatchMLAPagedAttentionWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user-reserved workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128 MB; the device of the workspace buffer should be the same as the device of the input tensors.
use_cuda_graph (bool, optional) – Whether to enable CUDA graph capture for the prefill kernels. If enabled, the auxiliary data structures will be stored in the provided buffers. The batch_size cannot change during the lifecycle of this wrapper when CUDA graph is enabled.
qo_indptr_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the qo_indptr array; the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.
kv_indptr_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the kv_indptr array; the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.
kv_indices_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the kv_indices array. This argument is only effective when use_cuda_graph is True.
kv_len_arr_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the kv_len_arr array; the size of the buffer should be [batch_size]. This argument is only effective when use_cuda_graph is True.
backend (str) – The implementation backend, default is “fa2”.
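For example, a CUDA-graph-compatible construction might pre-allocate all auxiliary buffers up front. This is a minimal sketch; the keyword names follow the constructor signature above, and the chosen sizes (including max_num_pages) are illustrative assumptions.

import torch
import flashinfer

batch_size = 114            # must stay fixed for the lifetime of the wrapper
max_num_pages = 128 * 1024  # assumed upper bound on total pages across the batch

float_workspace_buffer = torch.empty(
    128 * 1024 * 1024, dtype=torch.int8, device="cuda"
)

# User-reserved buffers backing the auxiliary arrays described above.
qo_indptr_buf = torch.empty(batch_size + 1, dtype=torch.int32, device="cuda")
kv_indptr_buf = torch.empty(batch_size + 1, dtype=torch.int32, device="cuda")
kv_indices_buf = torch.empty(max_num_pages, dtype=torch.int32, device="cuda")
kv_len_arr_buf = torch.empty(batch_size, dtype=torch.int32, device="cuda")

mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
    float_workspace_buffer,
    use_cuda_graph=True,
    qo_indptr=qo_indptr_buf,
    kv_indptr=kv_indptr_buf,
    kv_indices=kv_indices_buf,
    kv_len_arr=kv_len_arr_buf,
    backend="fa2",
)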
- plan(qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, kv_len_arr: torch.Tensor, num_heads: int, head_dim_ckv: int, head_dim_kpe: int, page_size: int, causal: bool, sm_scale: float, q_data_type: torch.dtype, kv_data_type: torch.dtype) None ¶
Plan the MLA attention computation.
- Parameters:
qo_indptr (torch.Tensor) – The indptr of the query/output tensor, shape: [batch_size + 1]. For decoding attention, the length of each query is 1, and the content of the tensor should be [0, 1, 2, ..., batch_size].
kv_indptr (torch.Tensor) – The indptr of the paged kv-cache, shape: [batch_size + 1].
kv_indices (torch.Tensor) – The page indices of the paged kv-cache, shape: [kv_indptr[-1]] or larger.
kv_len_arr (torch.Tensor) – The kv length of each request, shape: [batch_size].
num_heads (int) – The number of heads in the query/output tensor.
head_dim_ckv (int) – The head dimension of compressed-kv.
head_dim_kpe (int) – The head dimension for rope k-cache.
page_size (int) – The page size of the paged kv-cache.
causal (bool) – Whether to use causal attention.
sm_scale (float) – The scale factor for softmax operation.
q_data_type (torch.dtype) – The data type of the query tensor.
kv_data_type (torch.dtype) – The data type of the kv-cache tensor.
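To make the relationship between these arguments concrete, the sketch below builds consistent plan inputs for a decode batch with page_size > 1; the batch size, page size, and per-request kv lengths are illustrative assumptions.

import torch

batch_size, page_size = 3, 64
kv_len_arr = torch.tensor([100, 37, 256], dtype=torch.int32, device="cuda")

# Pages needed per request (ceiling division), and the indptr/indices that
# describe which pages of the cache each request owns.
pages_per_req = (kv_len_arr + page_size - 1) // page_size      # [2, 1, 4]
kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(pages_per_req, dim=0)             # [0, 2, 3, 7]
kv_indices = torch.arange(int(kv_indptr[-1]), dtype=torch.int32, device="cuda")

# Decode: each request contributes exactly one query token.
qo_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda")

# These, together with num_heads, head_dim_ckv=512, head_dim_kpe=64, page_size,
# causal, sm_scale, and the dtypes, are then passed to plan().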
- run(q_nope: torch.Tensor, q_pe: torch.Tensor, ckv_cache: torch.Tensor, kpe_cache: torch.Tensor, return_lse: Literal[False] = False) torch.Tensor ¶
- run(q_nope: torch.Tensor, q_pe: torch.Tensor, ckv_cache: torch.Tensor, kpe_cache: torch.Tensor, return_lse: Literal[True] = True) Tuple[torch.Tensor, torch.Tensor]
Run the MLA attention computation.
- Parameters:
q_nope (torch.Tensor) – The query tensor without rope, shape: [batch_size, num_heads, head_dim_ckv].
q_pe (torch.Tensor) – The rope part of the query tensor, shape: [batch_size, num_heads, head_dim_kpe].
ckv_cache (torch.Tensor) – The compressed kv-cache tensor (without rope), shape: [num_pages, page_size, head_dim_ckv]. head_dim_ckv is 512 in DeepSeek v2/v3 models.
kpe_cache (torch.Tensor) – The rope part of the kv-cache tensor, shape: [num_pages, page_size, head_dim_kpe]. head_dim_kpe is 64 in DeepSeek v2/v3 models.
out (Optional[torch.Tensor]) – The output tensor, if not provided, will be allocated internally.
lse (Optional[torch.Tensor]) – The log-sum-exp of attention logits, if not provided, will be allocated internally.
return_lse (bool, optional) – Whether to return the log-sum-exp value, default is False.
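Continuing the example at the top of this page after plan() has been called, requesting the log-sum-exp alongside the output could look like the sketch below; the lse shape comment reflects this decode setting, where each request has one query token.

# Hedged sketch: run with return_lse=True, reusing q_nope/q_pe/ckv/kpe from the
# example above after mla_wrapper.plan(...) has been called.
o, lse = mla_wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)

# o   : attention output in the latent space, torch.Size([114, 128, 512])
# lse : log-sum-exp of the attention logits per query token and head,
#       torch.Size([114, 128]) in this example; useful when merging partial
#       attention states across cache segments.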