FlashInfer Attention Kernels
flashinfer.decode
Single Request Decoding
Decode attention with KV Cache for single request, return attention output.
Single-request decode using a pre-compiled JIT module.
Batch Decoding
Batched decode attention with paged KV cache, backed by cuDNN SDPA.
- class flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(float_workspace_buffer: Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, use_tensor_cores: bool = False, paged_kv_indptr_buffer: Tensor | None = None, paged_kv_indices_buffer: Tensor | None = None, paged_kv_last_page_len_buffer: Tensor | None = None, backend: str = 'auto', jit_args: List[Any] | None = None)
Wrapper class for decode attention with paged kv-cache (first proposed in vLLM) for a batch of requests.
Check our tutorial for page table layout.
Examples
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 8
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> kv_page_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= kv_last_page_len <= page_size
>>> kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> kv_cache_at_layer = [
...     torch.randn(
...         max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> # create auxiliary data structures for batch decode attention
>>> decode_wrapper.plan(
...     kv_page_indptr,
...     kv_page_indices,
...     kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     pos_encoding_mode="NONE",
...     data_type=torch.float16
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
...     kv_cache = kv_cache_at_layer[i]
...     # compute batch decode attention, reuse auxiliary data structures for all layers
...     o = decode_wrapper.run(q, kv_cache)
...     outputs.append(o)
...
>>> outputs[0].shape
torch.Size([7, 64, 128])
Note
To accelerate computation, FlashInfer’s batch decode attention creates auxiliary data structures that can be reused across multiple batch decode attention calls (e.g. different Transformer layers). This wrapper class manages the lifecycle of these data structures.
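The page-table tensors in the example above follow a CSR-style layout: `kv_page_indptr[i]:kv_page_indptr[i + 1]` selects the pages of request `i`, and `kv_last_page_len[i]` gives the number of valid entries in that request's final page. The sketch below recomputes each request's KV length from those values in plain Python. `kv_lengths` is a hypothetical helper written for illustration, not a FlashInfer function.

```python
def kv_lengths(indptr, last_page_len, page_size):
    """Return the number of cached KV tokens for each request.

    Every page of a request is full except possibly the last one,
    which holds `last_page_len[i]` entries.
    """
    lengths = []
    for i, last in enumerate(last_page_len):
        if not 1 <= last <= page_size:
            raise ValueError("last_page_len must be in [1, page_size]")
        num_pages = indptr[i + 1] - indptr[i]
        lengths.append((num_pages - 1) * page_size + last)
    return lengths


# Values copied from the example above.
kv_page_indptr = [0, 17, 29, 44, 48, 66, 100, 128]
kv_last_page_len = [1, 7, 14, 4, 3, 1, 16]
print(kv_lengths(kv_page_indptr, kv_last_page_len, page_size=16))
# [257, 183, 238, 52, 275, 529, 448]
```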
- __init__(float_workspace_buffer: Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, use_tensor_cores: bool = False, paged_kv_indptr_buffer: Tensor | None = None, paged_kv_indices_buffer: Tensor | None = None, paged_kv_last_page_len_buffer: Tensor | None = None, backend: str = 'auto', jit_args: List[Any] | None = None) -> None
Constructor of BatchDecodeWithPagedKVCacheWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. Must be initialized to 0 for its first use. The recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors.
kv_layout (str) – The layout of the input k/v tensors, could be either NHD or HND.
use_cuda_graph (bool) – Whether to enable CUDAGraph for batch decode attention, if enabled, the auxiliary data structures will be stored in the provided buffers. The batch_size cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.
use_tensor_cores (bool) – Whether to use tensor cores for the computation. Will be faster for large group size in grouped query attention. Defaults to False.
paged_kv_indptr_buffer (Optional[torch.Tensor]) – The user reserved buffer on GPU to store the indptr of the paged kv cache, the size of the buffer should be [batch_size + 1]. Only needed when use_cuda_graph is True.
paged_kv_indices_buffer (Optional[torch.Tensor]) – The user reserved buffer on GPU to store the page indices of the paged kv cache, should be large enough to store the maximum number of page indices (max_num_pages) during the lifecycle of this wrapper. Only needed when use_cuda_graph is True.
paged_kv_last_page_len_buffer (Optional[torch.Tensor]) – The user reserved buffer on GPU to store the number of entries in the last page, the size of the buffer should be [batch_size]. Only needed when use_cuda_graph is True.
backend (str) – The implementation backend, could be auto/fa2/fa3/trtllm-gen or cute-dsl. Defaults to auto. If set to auto, the wrapper will automatically choose the backend based on the device architecture and kernel availability. The cute-dsl backend uses the CuTe DSL GQA decode kernel for Blackwell (SM100+) and only supports a subset of features (equal head_dim_qk/vo, no RoPE/ALiBi/soft-cap/sliding window).
jit_args (Optional[List[Any]]) – If provided, the wrapper will use the provided arguments to create the JIT module, otherwise, the wrapper will use default attention implementation.
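When `use_cuda_graph` is `True`, the three page-table buffers have to be preallocated with the sizes listed above. As a rough sketch, the helper below collects those element counts for a fixed batch size and page budget. `cuda_graph_buffer_sizes` is a hypothetical name used only for illustration; it is not part of FlashInfer.

```python
def cuda_graph_buffer_sizes(batch_size, max_num_pages):
    """Element counts of the page-table buffers described above.

    Hypothetical helper: it only restates the documented sizes.
    """
    return {
        "paged_kv_indptr_buffer": batch_size + 1,
        "paged_kv_indices_buffer": max_num_pages,
        "paged_kv_last_page_len_buffer": batch_size,
    }


sizes = cuda_graph_buffer_sizes(batch_size=7, max_num_pages=128)
for name, numel in sizes.items():
    print(f"{name}: {numel} elements")
```

Assuming the buffers use the same `torch.int32` dtype as the `plan()` inputs, each one would typically be allocated once on the GPU (for example with `torch.empty(numel, dtype=torch.int32, device="cuda:0")`) and then passed to the constructor.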
- plan(indptr: Tensor, indices: Tensor, last_page_len: Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, pos_encoding_mode: str = 'NONE', window_left: int = -1, logits_soft_cap: float | None = None, q_data_type: str | dtype | None = 'float16', kv_data_type: str | dtype | None = None, o_data_type: str | dtype | None = None, data_type: str | dtype | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, non_blocking: bool = True, block_tables: Tensor | None = None, seq_lens: Tensor | None = None, fixed_split_size: int | None = None, disable_split_kv: bool = False, q_len_per_req: int = 1) -> None
Plan batch decode for given problem specification.
- Parameters:
indptr (torch.Tensor) – The indptr of the paged kv cache, shape: [batch_size + 1], dtype: torch.int32
indices (torch.Tensor) – The page indices of the paged kv cache, shape: [indptr[-1]], dtype: torch.int32
last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv cache, shape: [batch_size], dtype: torch.int32
num_qo_heads (int) – The number of query/output heads
num_kv_heads (int) – The number of key/value heads
head_dim (int) – The dimension of the heads
page_size (int) – The page size of the paged kv cache
pos_encoding_mode (str) – The position encoding applied inside attention kernels, could be NONE/ROPE_LLAMA (LLAMA style rotary embedding)/ALIBI. Defaults to NONE.
window_left (int) – The left (inclusive) window size for the attention window, when set to -1, the window size will be set to the full length of the sequence. Defaults to -1.
logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to 0. If greater than 0, the logits will be capped according to formula: \(\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})\), where \(x\) is the input logits.
q_data_type (Optional[Union[str, torch.dtype]]) – The data type of the query tensor, defaults to torch.float16.
kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, will be set to q_data_type. Defaults to None.
o_data_type (Optional[Union[str, torch.dtype]]) – The data type of the output tensor. If None, will be set to q_data_type. For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
data_type (Optional[Union[str, torch.dtype]]) – The data type of both the query and key/value tensors. Defaults to torch.float16. data_type is deprecated, please use q_data_type and kv_data_type instead.
sm_scale (Optional[float]) – Softmax scale. If None, defaults to 1 / sqrt(head_dim). Cached on the wrapper and reused at run() time.
rope_scale (Optional[float]) – Scale factor applied during RoPE interpolation. Only consulted when pos_encoding_mode != "NONE". Defaults to 1.0 when None.
rope_theta (Optional[float]) – Base value for the RoPE frequencies. Only consulted when pos_encoding_mode != "NONE". Defaults to 1e4 when None.
non_blocking (bool) – Whether to copy the input tensors to the device asynchronously, defaults to True.
seq_lens (Optional[torch.Tensor]) – A uint32 1D tensor indicating the kv sequence length of each prompt. shape: [batch_size].
block_tables (Optional[torch.Tensor]) – A uint32 2D tensor indicating the block table of each prompt. shape: [batch_size, max_num_blocks_per_seq].
fixed_split_size (Optional[int]) – The fixed split size for FA2 split-kv decode, in pages. Only supported by tensor core decode for now. Recommend setting to the average sequence length of your workload. When enabled for FA2, will lead to deterministic softmax score reduction in the merge_states kernel, and therefore batch-size invariant outputs. See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ Note that compatibility with CUDA graph is NOT guaranteed, as even when the batch size is fixed, the kv sequence length can change and lead to a varied number of launched CTAs.
disable_split_kv (bool) – Whether to disable the split-kv for determinism in CUDA Graph, defaults to False.
q_len_per_req (int) – The number of query tokens per request. Defaults to 1.
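To make the `logits_soft_cap` formula above concrete, here is a minimal scalar sketch in plain Python. `soft_cap` is a hypothetical reference function written for this page, not the kernel implementation; it treats a value of `0` (or less) as "capping disabled", matching the description above.

```python
import math


def soft_cap(logit, logits_soft_cap):
    """Return logits_soft_cap * tanh(logit / logits_soft_cap).

    A non-positive cap leaves the logit unchanged.
    """
    if logits_soft_cap <= 0:
        return logit
    return logits_soft_cap * math.tanh(logit / logits_soft_cap)


# Small logits pass through almost unchanged; large ones saturate
# smoothly and never exceed the cap in magnitude.
for x in (-100.0, -1.0, 0.0, 1.0, 100.0):
    print(x, soft_cap(x, 30.0))
```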
Note
The plan() method should be called before any run() or run_return_lse() calls. Auxiliary data structures are created during this call and cached for multiple run calls.
The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.
The plan() method cannot be used in CUDA Graph or in torch.compile.
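The divisibility requirement in the note above can be checked before calling `plan()`. The sketch below is a plain-Python illustration with hypothetical helper names. The mapping from a query head to its key/value head (`qo_head // group_size`) is the usual grouped-query-attention convention and is an assumption here, not something this page states.

```python
def gqa_group_size(num_qo_heads, num_kv_heads):
    """Number of query heads that share one key/value head."""
    if num_kv_heads <= 0 or num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    return num_qo_heads // num_kv_heads


def kv_head_for(qo_head, num_qo_heads, num_kv_heads):
    """KV head read by a given query head (assumed contiguous grouping)."""
    return qo_head // gqa_group_size(num_qo_heads, num_kv_heads)


# Head counts from the decode example: 64 query heads, 8 KV heads.
print(gqa_group_size(64, 8))   # 8
print(kv_head_for(63, 64, 8))  # 7
```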
- reset_workspace_buffer(float_workspace_buffer: Tensor, int_workspace_buffer: Tensor) -> None
Reset the workspace buffer.
- Parameters:
float_workspace_buffer (torch.Tensor) – The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors.
int_workspace_buffer (torch.Tensor) – The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors.
- run(q: Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], *args, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, out: Tensor | None = None, lse: Tensor | None = None, return_lse: Literal[False] = False, enable_pdl: bool | None = None, window_left: int | None = None, sinks: Tensor | None = None, q_len_per_req: int | None = None, skip_softmax_threshold_scale_factor: float | None = None, kv_cache_sf: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None) -> Tensor
- run(q: Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], *args, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, out: Tensor | None = None, lse: Tensor | None = None, return_lse: Literal[True] = True, enable_pdl: bool | None = None, window_left: int | None = None, sinks: Tensor | None = None, q_len_per_req: int | None = None, skip_softmax_threshold_scale_factor: float | None = None, kv_cache_sf: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None) -> Tuple[Tensor, Tensor]
Compute batch decode attention between query and paged kv cache.
- Parameters:
q (torch.Tensor) – The query tensor, shape: [batch_size * q_len_per_req, num_qo_heads, head_dim]. q_len_per_req doesn’t need to match the value passed to plan().
paged_kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – The paged KV-Cache stored as a tuple of tensors or a single tensor:
- a tuple (k_cache, v_cache) of 4-D tensors, each with shape: [max_num_pages, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, num_kv_heads, page_size, head_dim] if kv_layout is HND.
- a single 5-D tensor with shape: [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND, where paged_kv_cache[:, 0] is the key-cache and paged_kv_cache[:, 1] is the value-cache.
*args – Additional arguments for the custom kernel.
q_scale (Optional[float]) – The calibration scale of query for fp8 input, if not provided, will be set to 1.0.
k_scale (Optional[float]) – The calibration scale of key for fp8 or nvfp4 input, if not provided, will be set to 1.0.
v_scale (Optional[float]) – The calibration scale of value for fp8 or nvfp4 input, if not provided, will be set to 1.0.
out (Optional[torch.Tensor]) – The output tensor, if not provided, will be allocated internally. Must be zero-initialized for the cute-dsl backend.
lse (Optional[torch.Tensor]) – The log-sum-exp of attention logits, if not provided, will be allocated internally.
return_lse (bool) – Whether to return the logsumexp of attention scores, defaults to False.
enable_pdl (bool) – Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Only supported for >= sm90, and currently only for FA2 and CUDA core decode.
window_left (Optional[int]) – Per-call sliding-window bound. Must either be None (default, inherit the value from plan()) or equal the value passed to plan() (the kernel asserts this). Passing -1 to plan() disables the sliding window for the entire batch.
sinks (Optional[torch.Tensor]) – Per-head attention sink logits, shape [num_qo_heads]. When provided, sinks[head_idx] is appended to each row of the softmax denominator (Streaming-LLM / Attention-Sinks). The dtype requirement is backend-specific and validated by the underlying kernel; pass None to disable. Not supported by the cute-dsl backend.
q_len_per_req (Optional[int]) – DEPRECATED: pass to plan() instead. When provided here, emits a DeprecationWarning and is used to validate the run-time value inferred from q.size(0). Scheduled for removal in a future release.
skip_softmax_threshold_scale_factor (Optional[float]) – Threshold scale factor for skipping softmax operations. Providing a value for this parameter enables skip-softmax sparsity as described in https://arxiv.org/abs/2512.12087. If no value is provided, standard attention is used. Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation. The actual threshold value equals the provided threshold_scale_factor divided by the context length.
kv_cache_sf (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]) – Per-block scale factors for NVFP4 KV cache. Accepts the same formats as paged_kv_cache:
- a tuple (k_scales, v_scales) of 4-D tensors, each with shape: [num_pages, page_size, num_kv_heads, head_dim // 16] if kv_layout is NHD, and [num_pages, num_kv_heads, page_size, head_dim // 16] if kv_layout is HND.
- a single 5-D tensor with shape: [num_pages, 2, page_size, num_kv_heads, head_dim // 16] if kv_layout is NHD, and [num_pages, 2, num_kv_heads, page_size, head_dim // 16] if kv_layout is HND, where dim 1 holds k (index 0) and v (index 1) scales.
Both tensors have dtype torch.float8_e4m3fn.
Currently, NVFP4 KV is supported by the fa2 and trtllm-gen backends.
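The shape rules for `paged_kv_cache` and `kv_cache_sf` listed above can be summarized in a few lines of plain Python. Both functions below are hypothetical helpers written for illustration; they only restate the documented shapes and do not call FlashInfer.

```python
def paged_kv_shape(max_num_pages, page_size, num_kv_heads, head_dim,
                   kv_layout="NHD", combined=True):
    """Expected shape of the paged KV cache.

    combined=True  -> single 5-D tensor (dim 1 holds K at 0 and V at 1)
    combined=False -> shape of each 4-D tensor in the (k_cache, v_cache) tuple
    """
    if kv_layout == "NHD":
        per_page = (page_size, num_kv_heads, head_dim)
    elif kv_layout == "HND":
        per_page = (num_kv_heads, page_size, head_dim)
    else:
        raise ValueError("kv_layout must be 'NHD' or 'HND'")
    prefix = (max_num_pages, 2) if combined else (max_num_pages,)
    return prefix + per_page


def nvfp4_scale_shape(kv_shape):
    """Scale-factor shape: same layout, last dimension is head_dim // 16."""
    return kv_shape[:-1] + (kv_shape[-1] // 16,)


shape = paged_kv_shape(128, 16, 8, 128, kv_layout="NHD")
print(shape)                     # (128, 2, 16, 8, 128)
print(nvfp4_scale_shape(shape))  # (128, 2, 16, 8, 8)
```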
- Returns:
If return_lse is False, the attention output, shape: [batch_size, num_qo_heads, head_dim]. If return_lse is True, a tuple of two tensors:
- attention output, shape: [batch_size, num_qo_heads, head_dim]
- logsumexp of attention scores, shape: [batch_size, num_qo_heads].
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
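One plausible reading of the `sinks` parameter of `run()` is that the sink logit contributes to the softmax denominator without contributing a value vector, so the attention weights of a row sum to less than one. The sketch below shows that idea for a single head and a single query row in plain Python. It is an illustrative reference under that assumption, not FlashInfer's kernel, and it omits scaling, masking and the value product.

```python
import math


def softmax_with_sink(logits, sink=None):
    """Softmax over `logits`, optionally with a sink logit in the denominator."""
    m = max(logits) if sink is None else max(max(logits), sink)
    exps = [math.exp(x - m) for x in logits]
    denom = sum(exps)
    if sink is not None:
        denom += math.exp(sink - m)  # the sink joins the denominator only
    return [e / denom for e in exps]


plain = softmax_with_sink([1.0, 2.0, 3.0])
sunk = softmax_with_sink([1.0, 2.0, 3.0], sink=3.0)
print(sum(plain))  # sums to 1 (up to rounding)
print(sum(sunk))   # strictly less than 1: the sink absorbs some mass
```

The relative weights between real tokens are unchanged in this sketch; only the total mass assigned to them shrinks.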
- class flashinfer.decode.BatchDecodeMlaWithPagedKVCacheWrapper(float_workspace_buffer: Tensor, use_cuda_graph: bool = False, use_tensor_cores: bool = False, paged_kv_indptr_buffer: Tensor | None = None, paged_kv_indices_buffer: Tensor | None = None, paged_kv_last_page_len_buffer: Tensor | None = None)
Warning: this class is deprecated and will be removed in a future release. Please use flashinfer.mla.BatchMLAPagedAttentionWrapper instead, which provides a more efficient and general MLA implementation that supports decode and incremental prefill.
- __init__(float_workspace_buffer: Tensor, use_cuda_graph: bool = False, use_tensor_cores: bool = False, paged_kv_indptr_buffer: Tensor | None = None, paged_kv_indices_buffer: Tensor | None = None, paged_kv_last_page_len_buffer: Tensor | None = None) -> None
Constructor of BatchDecodeMlaWithPagedKVCacheWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors.
use_cuda_graph (bool) – Whether to enable CUDAGraph for batch decode attention, if enabled, the auxiliary data structures will be stored in the provided buffers. The batch_size cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.
use_tensor_cores (bool) – Whether to use tensor cores for the computation. Will be faster for large group size in grouped query attention. Defaults to False.
paged_kv_indptr_buffer (Optional[torch.Tensor]) – The user reserved buffer on GPU to store the indptr of the paged kv cache, the size of the buffer should be [batch_size + 1]. Only needed when use_cuda_graph is True.
paged_kv_indices_buffer (Optional[torch.Tensor]) – The user reserved buffer on GPU to store the page indices of the paged kv cache, should be large enough to store the maximum number of page indices (max_num_pages) during the lifecycle of this wrapper. Only needed when use_cuda_graph is True.
paged_kv_last_page_len_buffer (Optional[torch.Tensor]) – The user reserved buffer on GPU to store the number of entries in the last page, the size of the buffer should be [batch_size]. Only needed when use_cuda_graph is True.
- plan(indptr: Tensor, indices: Tensor, last_page_len: Tensor, num_qo_heads: int, head_dim_compressed_kv: int, page_size: int, sm_scale: float, window_left: int = -1, logits_soft_cap: float | None = None, data_type: str | dtype = 'float16', q_data_type: str | dtype | None = None, rope_scale: float | None = None, rope_theta: float | None = None) -> None
Plan batch decode for given problem specification.
- Parameters:
indptr (torch.Tensor) – The indptr of the paged kv cache, shape: [batch_size + 1], dtype: torch.int32
indices (torch.Tensor) – The page indices of the paged kv cache, shape: [indptr[-1]], dtype: torch.int32
last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv cache, shape: [batch_size], dtype: torch.int32
num_qo_heads (int) – The number of query/output heads
head_dim_compressed_kv (int) – The dimension of the compressed kv, also known as kv_lora_rank
page_size (int) – The page size of the paged kv cache
sm_scale (float) – The scale of softmax, should be 1 / sqrt(qk_nope_head_dim + qk_rope_head_dim)
window_left (int) – The left (inclusive) window size for the attention window, when set to -1, the window size will be set to the full length of the sequence. Defaults to -1.
logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to 0. If greater than 0, the logits will be capped according to formula: \(\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})\), where \(x\) is the input logits.
data_type (Union[str, torch.dtype]) – The data type of the paged kv cache. Defaults to float16.
q_data_type (Optional[Union[str, torch.dtype]]) – The data type of the query tensor. If None, will be set to data_type. Defaults to None.
rope_scale (Optional[float]) – Scale factor applied during RoPE interpolation for the rope-portion of the MLA query. Defaults to 1.0 when None.
rope_theta (Optional[float]) – Base value for the RoPE frequencies of the rope-portion of the MLA query. Defaults to 1e4 when None.
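Unlike the non-MLA wrapper, this `plan()` has no default for `sm_scale`, so the caller computes it from the two query/key head dimensions. A minimal sketch of that computation follows; `mla_sm_scale` is a hypothetical helper, and the dimensions 128 and 64 are assumed example values rather than numbers taken from this page.

```python
import math


def mla_sm_scale(qk_nope_head_dim, qk_rope_head_dim):
    """Softmax scale 1 / sqrt(qk_nope_head_dim + qk_rope_head_dim)."""
    return 1.0 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)


# Assumed example dimensions: 128 (no-RoPE part) + 64 (RoPE part) = 192.
print(mla_sm_scale(128, 64))
```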
- reset_workspace_buffer(float_workspace_buffer: Tensor, int_workspace_buffer: Tensor) -> None
Reset the workspace buffer.
- Parameters:
float_workspace_buffer (torch.Tensor) – The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors.
int_workspace_buffer (torch.Tensor) – The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors.
- run(q_nope: Tensor, q_pe: Tensor, paged_ckv_cache: Tensor, paged_kpe_cache: Tensor, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, out: Tensor | None = None, lse: Tensor | None = None, return_lse: bool = False, enable_pdl: bool = False) -> Tensor | Tuple[Tensor, Tensor]
Compute batch decode attention between query and paged kv cache.
- Parameters:
q_nope (torch.Tensor) – The query tensor not related to ROPE, shape: [batch_size, num_qo_heads, head_dim_ckv]
q_pe (torch.Tensor) – The query tensor related to ROPE, shape: [batch_size, num_qo_heads, head_dim_kpe]
paged_ckv_cache (torch.Tensor) – The paged compressed-KV-Cache stored as a single 3-D tensor with shape: [max_num_pages, page_size, head_dim_ckv].
paged_kpe_cache (torch.Tensor) – The paged k-pe-Cache stored as a single 3-D tensor with shape: [max_num_pages, page_size, head_dim_kpe].
q_scale (Optional[float]) – The calibration scale of query for fp8 input, if not provided, will be set to 1.0.
k_scale (Optional[float]) – The calibration scale of key for fp8 input, if not provided, will be set to 1.0.
v_scale (Optional[float]) – The calibration scale of value for fp8 input, if not provided, will be set to 1.0.
out (Optional[torch.Tensor]) – The output tensor, if not provided, will be allocated internally.
lse (Optional[torch.Tensor]) – The log-sum-exp of attention logits, if not provided, will be allocated internally.
return_lse (bool) – Whether to return the logsumexp of attention scores, defaults to False.
enable_pdl (bool) – Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Only supported for >= sm90, and currently only for FA2 and CUDA core decode.
- Returns:
If return_lse is False, the attention output, shape: [batch_size, num_qo_heads, head_dim]. If return_lse is True, a tuple of two tensors:
- attention output, shape: [batch_size, num_qo_heads, head_dim]
- logsumexp of attention scores, shape: [batch_size, num_qo_heads].
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
- class flashinfer.decode.CUDAGraphBatchDecodeWithPagedKVCacheWrapper(workspace_buffer: Tensor, indptr_buffer: Tensor, indices_buffer: Tensor, last_page_len_buffer: Tensor, kv_layout: str = 'NHD', use_tensor_cores: bool = False)
CUDAGraph-compatible wrapper class for decode attention with paged kv-cache (first proposed in vLLM) for a batch of requests.
Note that this wrapper may not be as efficient as BatchDecodeWithPagedKVCacheWrapper because, to accommodate the CUDAGraph requirement, it does not dispatch to different kernels for different batch sizes, sequence lengths, etc.
Check our tutorial for page table layout.
Note
The plan() method cannot be captured by CUDAGraph.
See also
- __init__(workspace_buffer: Tensor, indptr_buffer: Tensor, indices_buffer: Tensor, last_page_len_buffer: Tensor, kv_layout: str = 'NHD', use_tensor_cores: bool = False) -> None
Constructor of CUDAGraphBatchDecodeWithPagedKVCacheWrapper.
- Parameters:
workspace_buffer (torch.Tensor) – The user reserved workspace buffer on GPU used to store auxiliary data structures, recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors.
indptr_buffer (torch.Tensor) – The user reserved buffer on GPU to store the indptr of the paged kv cache, should be large enough to store the indptr of maximum batch size ([max_batch_size + 1]) during the lifecycle of this wrapper.
indices_buffer (torch.Tensor) – The user reserved buffer on GPU to store the page indices of the paged kv cache, should be large enough to store the maximum number of page indices (max_num_pages) during the lifecycle of this wrapper.
last_page_len_buffer (torch.Tensor) – The user reserved buffer on GPU to store the number of entries in the last page, should be large enough to store the maximum batch size ([max_batch_size]) during the lifecycle of this wrapper.
use_tensor_cores (bool) – Whether to use tensor cores for the computation. Will be faster for large group size in grouped query attention. Defaults to False.
kv_layout (str) – The layout of the input k/v tensors, could be either NHD or HND.
XQA
Apply attention with paged KV cache using XQA kernel.
Apply attention with paged KV cache using XQA MLA (Multi-Head Latent Attention) kernel.
flashinfer.prefill
Attention kernels for prefill and append attention in both single-request and batch-serving settings.
Single Request Prefill/Append Attention
Prefill/Append attention with KV cache for single request, return the attention output.
Convenience wrapper for
Single-request prefill / append attention using a pre-compiled JIT module.
Batch Prefill/Append Attention
Batched prefill attention with paged KV cache, backed by cuDNN SDPA.
TRT-LLM FMHAv2 prefill attention.
- class flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer: Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: Tensor | None = None, paged_kv_indptr_buf: Tensor | None = None, paged_kv_indices_buf: Tensor | None = None, paged_kv_last_page_len_buf: Tensor | None = None, custom_mask_buf: Tensor | None = None, mask_indptr_buf: Tensor | None = None, backend: str = 'auto', jit_args: List[Any] | None = None, jit_kwargs: Dict[str, Any] | None = None)
Wrapper class for prefill/append attention with paged kv-cache for a batch of requests.
Check our tutorial for page table layout.
Example
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 16
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> nnz_qo = 100
>>> qo_indptr = torch.tensor(
...     [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
... )
>>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> paged_kv_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= paged_kv_last_page_len <= page_size
>>> paged_kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
>>> kv_cache_at_layer = torch.randn(
...     num_layers, max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
... )
>>> # create auxiliary data structures for batch prefill attention
>>> prefill_wrapper.plan(
...     qo_indptr,
...     paged_kv_indptr,
...     paged_kv_indices,
...     paged_kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     causal=True,
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     kv_cache = kv_cache_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o = prefill_wrapper.run(q, kv_cache)
...     outputs.append(o)
...
>>> outputs[0].shape
torch.Size([100, 64, 128])
>>>
>>> # below is another example of creating custom mask for batch prefill attention
>>> mask_arr = []
>>> qo_len = (qo_indptr[1:] - qo_indptr[:-1]).cpu().tolist()
>>> kv_len = (page_size * (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) + paged_kv_last_page_len).cpu().tolist()
>>> for i in range(batch_size):
...     mask_i = torch.tril(
...         torch.full((qo_len[i], kv_len[i]), True, device="cuda:0"),
...         diagonal=(kv_len[i] - qo_len[i]),
...     )
...     mask_arr.append(mask_i.flatten())
...
>>> mask = torch.cat(mask_arr, dim=0)
>>> prefill_wrapper.plan(
...     qo_indptr,
...     paged_kv_indptr,
...     paged_kv_indices,
...     paged_kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     custom_mask=mask,
... )
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     kv_cache = kv_cache_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o_custom = prefill_wrapper.run(q, kv_cache)
...     assert torch.allclose(o_custom, outputs[i], rtol=1e-3, atol=1e-3)
...
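The custom mask in the second half of the example reproduces causal attention: with `diagonal=kv_len - qo_len`, the last query row can see every cached token and each earlier row sees one token fewer. The plain-Python sketch below builds the same boolean pattern without PyTorch so the alignment is easy to inspect. `causal_mask` and `flatten` are hypothetical helpers for illustration only.

```python
def causal_mask(qo_len, kv_len):
    """Same pattern as tril(full((qo_len, kv_len), True), diagonal=kv_len - qo_len)."""
    offset = kv_len - qo_len
    return [[j <= i + offset for j in range(kv_len)] for i in range(qo_len)]


def flatten(mask):
    """Row-major flattening, as done per request before concatenation."""
    return [value for row in mask for value in row]


# Two new query tokens attending to four cached tokens.
for row in causal_mask(2, 4):
    print("".join("x" if visible else "." for visible in row))
# xxx.
# xxxx
```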
Note
To accelerate computation, FlashInfer’s batch prefill/append attention operators create auxiliary data structures that can be reused across multiple prefill/append attention calls (e.g. different Transformer layers). This wrapper class manages the lifecycle of these data structures.
- __init__(float_workspace_buffer: Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: Tensor | None = None, paged_kv_indptr_buf: Tensor | None = None, paged_kv_indices_buf: Tensor | None = None, paged_kv_last_page_len_buf: Tensor | None = None, custom_mask_buf: Tensor | None = None, mask_indptr_buf: Tensor | None = None, backend: str = 'auto', jit_args: List[Any] | None = None, jit_kwargs: Dict[str, Any] | None = None) -> None
Constructor of BatchPrefillWithPagedKVCacheWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user reserved workspace buffer used to store intermediate attention results in split-k algorithm. The recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors.
kv_layout (str) – The layout of the input k/v tensors, could be either NHD or HND.
use_cuda_graph (bool) – Whether to enable CUDA graph capture for the prefill kernels, if enabled, the auxiliary data structures will be stored in provided buffers. The batch_size cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.
qo_indptr_buf (Optional[torch.Tensor]) – The user reserved buffer to store the qo_indptr array, the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.
paged_kv_indptr_buf (Optional[torch.Tensor]) – The user reserved buffer to store the paged_kv_indptr array, the size of this buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.
paged_kv_indices_buf (Optional[torch.Tensor]) – The user reserved buffer to store the paged_kv_indices array, should be large enough to store the maximum possible size of the paged_kv_indices array during the lifetime of the wrapper. This argument is only effective when use_cuda_graph is True.
paged_kv_last_page_len_buf (Optional[torch.Tensor]) – The user reserved buffer to store the paged_kv_last_page_len array, the size of the buffer should be [batch_size]. This argument is only effective when use_cuda_graph is True.
custom_mask_buf (Optional[torch.Tensor]) – The user reserved buffer to store the custom mask tensor, should be large enough to store the maximum possible size of the packed custom mask tensor during the lifetime of the wrapper. This argument is only effective when use_cuda_graph is set to True and the custom mask will be used in attention computation.
mask_indptr_buf (Optional[torch.Tensor]) – The user reserved buffer to store the mask_indptr array, the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True and the custom mask will be used in attention computation.
backend (str) – The implementation backend, could be auto/fa2/fa3/cudnn or trtllm-gen. Defaults to auto. If set to auto, the wrapper will automatically choose the backend based on the device architecture and kernel availability.
jit_args (Optional[List[Any]]) – If provided, the wrapper will use the provided arguments to create the JIT module, otherwise, the wrapper will use default attention implementation.
jit_kwargs (Optional[Dict[str, Any]]) – The keyword arguments to create the JIT module, defaults to None.
- plan(qo_indptr: Tensor, paged_kv_indptr: Tensor, paged_kv_indices: Tensor, paged_kv_last_page_len: Tensor, num_qo_heads: int, num_kv_heads: int, head_dim_qk: int, page_size: int, head_dim_vo: int | None = None, custom_mask: Tensor | None = None, packed_custom_mask: Tensor | None = None, causal: bool = False, pos_encoding_mode: str = 'NONE', use_fp16_qk_reduction: bool = False, sm_scale: float | None = None, window_left: int = -1, logits_soft_cap: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, q_data_type: str | dtype = 'float16', kv_data_type: str | dtype | None = None, o_data_type: str | dtype | None = None, non_blocking: bool = True, prefix_len_ptr: Tensor | None = None, token_pos_in_items_ptr: Tensor | None = None, token_pos_in_items_len: int = 0, max_item_len_ptr: Tensor | None = None, seq_lens: Tensor | None = None, seq_lens_q: Tensor | None = None, block_tables: Tensor | None = None, max_token_per_sequence: int | None = None, max_sequence_kv: int | None = None, fixed_split_size: int | None = None, disable_split_kv: bool = False) None¶
Plan batch prefill/append attention on Paged KV-Cache for given problem specification.
- Parameters:
qo_indptr (torch.Tensor) – The indptr of the query/output tensor, shape:
[batch_size + 1].paged_kv_indptr (torch.Tensor) – The indptr of the paged kv-cache, shape:
[batch_size + 1].paged_kv_indices (torch.Tensor) – The page indices of the paged kv-cache, shape:
[paged_kv_indptr[-1]].paged_kv_last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv-cache, shape:
[batch_size].num_qo_heads (int) – The number of query/output heads.
num_kv_heads (int) – The number of key/value heads.
head_dim_qk (int) – The dimension of the query/key heads.
page_size (int) – The size of each page in the paged kv-cache.
head_dim_vo (Optional[int]) – The dimension of the value/output heads, if not provided, will be set to
head_dim_qk.custom_mask (Optional[torch.Tensor]) –
The flattened boolean mask tensor, shape:
(sum(q_len[i] * k_len[i] for i in range(batch_size)). The elements in the mask tensor should be eitherTrueorFalse, whereFalsemeans the corresponding element in the attention matrix will be masked out.Please refer to the mask layout for more details about flattened layout of mask tensor.
When
custom_maskis provided, andpacked_custom_maskis not, the function will pack the custom mask tensor into a 1D packed mask tensor, which introduces additional overhead.packed_custom_mask (Optional[torch.Tensor]) – The 1D packed uint8 mask tensor, if provided, the
custom_maskwill be ignored. The packed mask tensor is generated byflashinfer.quantization.packbits().causal (bool) – Whether to apply causal mask to the attention matrix. This is only effective when
custom_maskis not provided inplan().pos_encoding_mode (str) – The position encoding applied inside attention kernels, could be
NONE/ROPE_LLAMA(LLAMA style rotary embedding) /ALIBI. Default isNONE.use_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).
window_left (int) – The left (inclusive) window size for the attention window, when set to
-1, the window size will be set to the full length of the sequence. Defaults to-1.logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to
0. If greater than 0, the logits will be capped according to formula: \(\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})\), where \(x\) is the input logits.sm_scale (Optional[float]) – The scale used in softmax, if not provided, will be set to
1.0 / sqrt(head_dim_qk).
rope_scale (Optional[float]) – The scale used in RoPE interpolation, if not provided, will be set to 1.0.
rope_theta (Optional[float]) – The theta used in RoPE, if not provided, will be set to 1e4.
q_data_type (Union[str, torch.dtype]) – The data type of the query tensor, defaults to torch.float16.
kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, will be set to
q_data_type.o_data_type (Optional[Union[str, torch.dtype]]) – The data type of the output tensor. If None, will be set to
q_data_type. For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.non_blocking (bool) – Whether to copy the input tensors to the device asynchronously, defaults to
True.prefix_len_ptr (Optional[torch.Tensor]) – prefix length. A uint32 1D tensor indicating the prefix length of each prompt. The tensor size is equal to the batch size.
token_pos_in_items_ptr (Optional[torch.Tensor]) – A 1D tensor (converted to uint16 inside FlashInfer) giving the position of each token within its item, restarting from 0 (the delimiter) at the start of every item. For example, a prompt with 3 items of lengths 3, 2 and 4 yields [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0], in which the four 0 entries are delimiters. For batch size > 1, the per-prompt vectors are concatenated into one 1D tensor, each zero-padded to the same length; the number of padding zeros for a prompt is token_pos_in_items_len minus the length of its raw token_pos_in_items_ptr vector.
token_pos_in_items_len (int) – The padded length of each prompt's token_pos_in_items_ptr vector, used to handle the batch size > 1 case. Continuing the 3, 2, 4 example, setting token_pos_in_items_len to 20 gives [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0] with 7 padding zeros. (The vector ends in 8 zeros; the first of them is the delimiter that closes the prompt.)
max_item_len_ptr (Optional[torch.Tensor]) – A uint16 vector containing, for each prompt, the maximum token length among its items.
seq_lens (Optional[torch.Tensor]) – A uint32 1D tensor indicating the kv sequence length of each prompt. shape:
[batch_size].seq_lens_q (Optional[torch.Tensor]) – A uint32 1D tensor indicating the q sequence length of each prompt. shape:
[batch_size]. If not provided, will be set to the same value asseq_lens.block_tables (Optional[torch.Tensor]) – A uint32 2D tensor indicating the block table of each prompt. shape:
[batch_size, max_num_blocks_per_seq].
max_token_per_sequence (Optional[int]) – Required for the cudnn backend. The scalar maximum token length over all sequences.
max_sequence_kv (Optional[int]) – Required for the cudnn backend. The scalar maximum sequence length over all sequences in the kv cache.
fixed_split_size (Optional[int]) – The fixed split size, in pages, for FA2 split-kv prefill/decode. Setting it to the average sequence length of your workload is recommended. When set, the softmax score reduction in the merge_states kernel becomes deterministic, which makes outputs batch-size invariant. See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ Note that compatibility with CUDA graph is NOT guaranteed: even when the batch size is fixed, the kv sequence length can change and lead to a varying number of launched CTAs.
disable_split_kv (bool) – Whether to disable split-kv for determinism in CUDA Graph, defaults to False.
Note
The plan() method should be called before any run() or run_return_lse() calls; auxiliary data structures are created during this call and cached for multiple kernel runs.
num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.
The plan() method cannot be used in CUDA Graph or in torch.compile.
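The page-table arguments of plan() can be derived from per-request KV lengths. The following is a minimal pure-Python sketch, assuming pages are handed out contiguously in request order (a real allocator will usually not do this); build_page_table is a hypothetical helper and not part of FlashInfer.

```python
def build_page_table(kv_lens, page_size):
    """Derive (indptr, indices, last_page_len) from per-request KV lengths.

    Assumption: pages are allocated contiguously in request order. This is
    only for illustration; a real page allocator may return any page ids.
    """
    indptr = [0]
    last_page_len = []
    for n in kv_lens:
        if n <= 0:
            raise ValueError("each request needs at least one KV entry")
        num_pages = (n + page_size - 1) // page_size  # ceil(n / page_size)
        indptr.append(indptr[-1] + num_pages)
        # Number of valid entries in the final page: 1 <= value <= page_size.
        last_page_len.append((n - 1) % page_size + 1)
    indices = list(range(indptr[-1]))  # hypothetical contiguous page ids
    return indptr, indices, last_page_len


indptr, indices, last_page_len = build_page_table([257, 183, 16], page_size=16)
print(indptr)         # [0, 17, 29, 30]
print(last_page_len)  # [1, 7, 16]
```

The three lists would then be converted to int32 tensors before being passed as paged_kv_indptr, paged_kv_indices and paged_kv_last_page_len.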
- reset_workspace_buffer(float_workspace_buffer: Tensor, int_workspace_buffer: Tensor) None¶
Reset the workspace buffer.
- Parameters:
float_workspace_buffer (torch.Tensor) – The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors.
int_workspace_buffer (torch.Tensor) – The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors.
- run(q: Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], *args, k_scale: float | None = None, v_scale: float | None = None, out: Tensor | None = None, lse: Tensor | None = None, return_lse: Literal[False] = False, enable_pdl: bool | None = None, window_left: int | None = None, sinks: Tensor | None = None, kv_cache_sf: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None, skip_softmax_threshold_scale_factor: float | None = None) Tensor¶
- run(q: Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], *args, k_scale: float | None = None, v_scale: float | None = None, out: Tensor | None = None, lse: Tensor | None = None, return_lse: Literal[True] = True, enable_pdl: bool | None = None, window_left: int | None = None, sinks: Tensor | None = None, kv_cache_sf: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None, skip_softmax_threshold_scale_factor: float | None = None) Tuple[Tensor, Tensor]
Compute batch prefill/append attention between query and paged kv-cache.
- Parameters:
q (torch.Tensor) – The query tensor, shape:
[qo_indptr[-1], num_qo_heads, head_dim]paged_kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) –
The paged KV-Cache stored as a tuple of tensors or a single tensor:
a tuple
(k_cache, v_cache)of 4-D tensors, each with shape:[max_num_pages, page_size, num_kv_heads, head_dim]ifkv_layoutisNHD, and[max_num_pages, num_kv_heads, page_size, head_dim]ifkv_layoutisHND.a single 5-D tensor with shape:
[max_num_pages, 2, page_size, num_kv_heads, head_dim]ifkv_layoutisNHD, and[max_num_pages, 2, num_kv_heads, page_size, head_dim]ifkv_layoutisHND. Wherepaged_kv_cache[:, 0]is the key-cache andpaged_kv_cache[:, 1]is the value-cache.
*args – Additional arguments for custom kernels.
q_scale (Optional[Union[float, torch.Tensor]]) – The calibration scale of query for fp8 input, if not provided, will be set to
1.0.k_scale (Optional[Union[float, torch.Tensor]]) – The calibration scale of key for fp8 or nvfp4 input, if not provided, will be set to
1.0.v_scale (Optional[Union[float, torch.Tensor]]) – The calibration scale of value for fp8 or nvfp4 input, if not provided, will be set to
1.0.out (Optional[torch.Tensor]) – The output tensor, if not provided, will be allocated internally.
lse (Optional[torch.Tensor]) – The log-sum-exp of attention logits, if not provided, will be allocated internally.
return_lse (bool) – Whether to return the logsumexp of attention output
enable_pdl (bool) – Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Only supported for >= sm90, and currently only for FA2 and CUDA core decode.
window_left (Optional[int]) – Per-call override for the left (inclusive) sliding-window size. When
None, the value supplied toplan()is used. Pass-1to disable the sliding window for this call.sinks (Optional[torch.Tensor]) – Per-head attention-sink logits. When provided, the kernel applies the attention-with-sink variant: an additional virtual token whose logit is
sinks[head_idx]is appended to each row of the softmax denominator. Shape:[num_qo_heads], dtypefloat32.kv_cache_sf (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]) –
Per-block scale factors for NVFP4 KV cache. Accepts the same formats as
paged_kv_cache:a tuple
(k_scales, v_scales)of 4-D tensors, each with shape:[num_pages, page_size, num_kv_heads, head_dim // 16]ifkv_layoutisNHD, and[num_pages, num_kv_heads, page_size, head_dim // 16]ifkv_layoutisHND.a single 5-D tensor with shape:
[num_pages, 2, page_size, num_kv_heads, head_dim // 16]ifkv_layoutisNHD, and[num_pages, 2, num_kv_heads, page_size, head_dim // 16]ifkv_layoutisHND, where dim 1 holds k (index 0) and v (index 1) scales.
Both tensors have dtype torch.float8_e4m3fn. k_scales uses a linear (row-major) layout, while v_scales must use TRT-LLM’s 4-token interleaved layout within each [page_size, head_dim // 16] tile if the backend is trtllm-gen. Use flashinfer.fp4_quantization.nvfp4_quantize_paged_kv_cache() to produce correctly formatted scale factors.
For the trtllm-gen backend with NHD layout, scale tensors are transposed to HND internally (incurring a copy). Use HND for better performance.
Currently, NVFP4 KV is supported by the fa2 and trtllm-gen backends.
skip_softmax_threshold_scale_factor (Optional[float]) – Threshold scale factor for skipping softmax operations. Providing a value enables skip-softmax sparsity as described in https://arxiv.org/abs/2512.12087. Defaults to
None(standard attention). Higher values yield faster kernels at the cost of accuracy; the effective threshold equals the supplied factor divided by the context length.
- Returns:
If return_lse is False, the attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim]. If return_lse is True, a tuple of two tensors:
The attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim].
The logsumexp of attention output, shape: [qo_indptr[-1], num_qo_heads].
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
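The effect of the sinks argument on a single softmax row can be sketched in pure Python. This is an illustration of the description above (one extra virtual token that contributes only to the denominator), assuming the sink logit is on the same scale as the other logits; softmax_with_sink is a hypothetical helper and does not reflect how the kernel is implemented.

```python
import math


def softmax_with_sink(logits, sink_logit):
    """Softmax over `logits` with one virtual sink token in the denominator.

    The sink receives probability mass but has no value vector, so the
    returned weights sum to less than 1 whenever the sink logit is finite.
    """
    m = max(max(logits), sink_logit)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    denom = sum(exps) + math.exp(sink_logit - m)
    return [e / denom for e in exps]


weights = softmax_with_sink([0.0, 0.0], sink_logit=0.0)
print(weights)  # each weight is 1/3; the remaining 1/3 goes to the sink
```

With a sink logit of negative infinity the extra term vanishes and the result reduces to the ordinary softmax.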
- class flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(float_workspace_buffer: Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: Tensor | None = None, kv_indptr_buf: Tensor | None = None, custom_mask_buf: Tensor | None = None, mask_indptr_buf: Tensor | None = None, backend: str = 'auto', jit_args: List[Any] | None = None, jit_kwargs: Dict[str, Any] | None = None)¶
Wrapper class for prefill/append attention with ragged (tensor) kv-cache for batch of requests.
Check our tutorial for ragged kv-cache layout.
Example
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 16
>>> head_dim = 128
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> nnz_kv = 100
>>> nnz_qo = 100
>>> qo_indptr = torch.tensor(
...     [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
... )
>>> kv_indptr = qo_indptr.clone()
>>> q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
>>> k_at_layer = torch.randn(num_layers, nnz_kv, num_kv_heads, head_dim).half().to("cuda:0")
>>> v_at_layer = torch.randn(num_layers, nnz_kv, num_kv_heads, head_dim).half().to("cuda:0")
>>> # create auxiliary data structures for batch prefill attention
>>> prefill_wrapper.plan(
...     qo_indptr,
...     kv_indptr,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     causal=True,
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     k = k_at_layer[i]
...     v = v_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o = prefill_wrapper.run(q, k, v)
...     outputs.append(o)
...
>>> outputs[0].shape
torch.Size([100, 64, 128])
>>>
>>> # below is another example of creating custom mask for batch prefill attention
>>> mask_arr = []
>>> qo_len = (qo_indptr[1:] - qo_indptr[:-1]).cpu().tolist()
>>> kv_len = (kv_indptr[1:] - kv_indptr[:-1]).cpu().tolist()
>>> for i in range(batch_size):
...     mask_i = torch.tril(
...         torch.full((qo_len[i], kv_len[i]), True, device="cuda:0"),
...         diagonal=(kv_len[i] - qo_len[i]),
...     )
...     mask_arr.append(mask_i.flatten())
...
>>> mask = torch.cat(mask_arr, dim=0)
>>> prefill_wrapper.plan(
...     qo_indptr,
...     kv_indptr,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     custom_mask=mask
... )
>>> outputs_custom_mask = []
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     k = k_at_layer[i]
...     v = v_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o_custom = prefill_wrapper.run(q, k, v)
...     assert torch.allclose(o_custom, outputs[i], rtol=1e-3, atol=1e-3)
...     outputs_custom_mask.append(o_custom)
...
>>> outputs_custom_mask[0].shape
torch.Size([100, 64, 128])
Note
To accelerate computation, FlashInfer’s batch prefill/append attention operators create some auxiliary data structures, these data structures can be reused across multiple prefill/append attention calls (e.g. different Transformer layers). This wrapper class manages the lifecycle of these data structures.
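The flattened custom mask built with torch.tril in the example above can also be written out element by element. The following pure-Python sketch mirrors that construction so the layout is easy to inspect without a GPU; flattened_causal_mask is a hypothetical helper, not a FlashInfer function.

```python
def flattened_causal_mask(qo_indptr, kv_indptr):
    """Concatenate per-request causal masks, each flattened row by row.

    For request i the mask has q_len[i] * kv_len[i] entries. Entry (r, c) is
    True when key c is visible to query r, i.e. c <= r + (kv_len - q_len),
    which matches torch.tril(..., diagonal=kv_len - q_len).
    """
    mask = []
    for i in range(len(qo_indptr) - 1):
        q_len = qo_indptr[i + 1] - qo_indptr[i]
        kv_len = kv_indptr[i + 1] - kv_indptr[i]
        offset = kv_len - q_len
        for r in range(q_len):
            for c in range(kv_len):
                mask.append(c <= r + offset)
    return mask


# Request 0: 2 queries over 3 keys; request 1: 1 query over 1 key.
mask = flattened_causal_mask([0, 2, 3], [0, 3, 4])
print(len(mask))  # 7 == 2 * 3 + 1 * 1
```

The total length equals sum(q_len[i] * k_len[i]) over the batch, which is the shape that custom_mask is expected to have.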
- __init__(float_workspace_buffer: Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: Tensor | None = None, kv_indptr_buf: Tensor | None = None, custom_mask_buf: Tensor | None = None, mask_indptr_buf: Tensor | None = None, backend: str = 'auto', jit_args: List[Any] | None = None, jit_kwargs: Dict[str, Any] | None = None) None¶
Constructor of
BatchPrefillWithRaggedKVCacheWrapper.- Parameters:
float_workspace_buffer (torch.Tensor) – The user reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors.
kv_layout (str) – The layout of the input k/v tensors, could be either
NHDorHND.use_cuda_graph (bool) – Whether to enable CUDA graph capture for the prefill kernels, if enabled, the auxiliary data structures will be stored as the provided buffers.
qo_indptr_buf (Optional[torch.Tensor]) – The user reserved GPU buffer to store the
qo_indptrarray, the size of the buffer should be[batch_size + 1]. This argument is only effective whenuse_cuda_graphisTrue.kv_indptr_buf (Optional[torch.Tensor]) – The user reserved GPU buffer to store the
kv_indptrarray, the size of the buffer should be[batch_size + 1]. This argument is only effective whenuse_cuda_graphisTrue.custom_mask_buf (Optional[torch.Tensor]) – The user reserved GPU buffer to store the custom mask tensor, should be large enough to store the maximum possible size of the packed custom mask tensor during the lifetime of the wrapper. This argument is only effective when
use_cuda_graphisTrueand custom mask will be used in attention computation.mask_indptr_buf (Optional[torch.Tensor]) – The user reserved GPU buffer to store the
mask_indptr array, the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True and custom mask will be used in attention computation.
backend (str) – The implementation backend, could be
auto/fa2/fa3/cudnn/cutlassorcute-dsl. Defaults toauto. If set toauto, the wrapper will automatically choose the backend based on the device architecture and kernel availability. Thecute-dslbackend uses the CuTe DSL attention kernel for Blackwell (SM100+).jit_args (Optional[List[Any]]) – If provided, the wrapper will use the provided arguments to create the JIT module, otherwise, the wrapper will use default attention implementation.
jit_kwargs (Optional[Dict[str, Any]]) – The keyword arguments to create the JIT module, defaults to None.
- plan(qo_indptr: Tensor, kv_indptr: Tensor, num_qo_heads: int, num_kv_heads: int, head_dim_qk: int, head_dim_vo: int | None = None, custom_mask: Tensor | None = None, packed_custom_mask: Tensor | None = None, causal: bool = False, pos_encoding_mode: str = 'NONE', use_fp16_qk_reduction: bool = False, window_left: int = -1, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, q_data_type: str | dtype = 'float16', kv_data_type: str | dtype | None = None, o_data_type: str | dtype | None = None, non_blocking: bool = True, prefix_len_ptr: Tensor | None = None, token_pos_in_items_ptr: Tensor | None = None, token_pos_in_items_len: int = 0, max_item_len_ptr: Tensor | None = None, fixed_split_size: int | None = None, disable_split_kv: bool = False, seq_lens: Tensor | None = None, seq_lens_q: Tensor | None = None, max_token_per_sequence: int | None = None, max_sequence_kv: int | None = None, v_indptr: Tensor | None = None, o_indptr: Tensor | None = None) None¶
Plan batch prefill/append attention on Ragged KV-Cache for given problem specification.
- Parameters:
qo_indptr (torch.Tensor) – The indptr of the query/output tensor, shape:
[batch_size + 1].kv_indptr (torch.Tensor) – The indptr of the key/value tensor, shape:
[batch_size + 1].num_qo_heads (int) – The number of query/output heads.
num_kv_heads (int) – The number of key/value heads.
head_dim_qk (int) – The dimension of the heads on query/key tensor.
head_dim_vo (Optional[int]) – The dimension of the heads on value/output tensor. If not provided, will be set to
head_dim_qk.custom_mask (Optional[torch.Tensor]) –
The flattened boolean mask tensor, shape:
(sum(q_len[i] * k_len[i] for i in range(batch_size)). The elements in the mask tensor should be eitherTrueorFalse, whereFalsemeans the corresponding element in the attention matrix will be masked out.Please refer to the mask layout for more details about flattened layout of mask tensor.
When
custom_maskis provided, andpacked_custom_maskis not, the function will pack the custom mask tensor into a 1D packed mask tensor, which introduces additional overhead.packed_custom_mask (Optional[torch.Tensor]) –
The 1D packed uint8 mask tensor, if provided, the
custom_maskwill be ignored. The packed mask tensor is generated byflashinfer.quantization.packbits().If provided, the custom mask will be added to the attention matrix before softmax and after scaling. The mask tensor should be in the same device as the input tensors.
causal (bool) – Whether to apply causal mask to the attention matrix. This argument is ignored if
custom_mask is provided in plan().
pos_encoding_mode (str) – The position encoding applied inside attention kernels, could be
NONE/ROPE_LLAMA(LLAMA style rotary embedding) /ALIBI. Default isNONE.use_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).
window_left (int) – The left (inclusive) window size for the attention window, when set to
-1, the window size will be set to the full length of the sequence. Defaults to-1.logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to
0. If greater than 0, the logits will be capped according to formula: \(\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})\), where \(x\) is the input logits.sm_scale (Optional[float]) – The scale used in softmax, if not provided, will be set to
1.0 / sqrt(head_dim_qk).rope_scale (Optional[float]) – The scale used in RoPE interpolation, if not provided, will be set to
1.0.rope_theta (Optional[float]) – The theta used in RoPE, if not provided, will be set to
1e4.q_data_type (Union[str, torch.dtype]) – The data type of the query tensor, defaults to torch.float16.
kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, will be set to
q_data_type.o_data_type (Optional[Union[str, torch.dtype]]) – The data type of the output tensor. If None, will be set to
q_data_type. For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.non_blocking (bool) – Whether to copy the input tensors to the device asynchronously, defaults to
True.prefix_len_ptr (Optional[torch.Tensor]) – prefix length. A uint32 1D tensor indicating the prefix length of each prompt. The tensor size is equal to the batch size.
token_pos_in_items_ptr (Optional[torch.Tensor]) – A 1D tensor (converted to uint16 inside FlashInfer) giving the position of each token within its item, restarting from 0 (the delimiter) at the start of every item. For example, a prompt with 3 items of lengths 3, 2 and 4 yields [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0], in which the four 0 entries are delimiters. For batch size > 1, the per-prompt vectors are concatenated into one 1D tensor, each zero-padded to the same length; the number of padding zeros for a prompt is token_pos_in_items_len minus the length of its raw token_pos_in_items_ptr vector.
token_pos_in_items_len (int) – The padded length of each prompt's token_pos_in_items_ptr vector, used to handle the batch size > 1 case. Continuing the 3, 2, 4 example, setting token_pos_in_items_len to 20 gives [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0] with 7 padding zeros. (The vector ends in 8 zeros; the first of them is the delimiter that closes the prompt.)
max_item_len_ptr (Optional[torch.Tensor]) – A uint16 vector containing, for each prompt, the maximum token length among its items.
fixed_split_size (Optional[int]) – The fixed split size, in pages, for split-kv FA2 prefill/decode. Setting it to the average sequence length of your workload is recommended. When set, the softmax score reduction in the merge_states kernel becomes deterministic, which makes outputs batch-size invariant. See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ Note that compatibility with CUDA graph is NOT guaranteed: even when the batch size is fixed, the kv sequence length can change and lead to a varying number of launched CTAs.
disable_split_kv (bool) – Whether to disable split-kv for determinism in CUDA Graph, defaults to False.
seq_lens (Optional[torch.Tensor]) – A uint32 1D tensor indicating the kv sequence length of each prompt. shape: [batch_size].
seq_lens_q (Optional[torch.Tensor]) – A uint32 1D tensor indicating the q sequence length of each prompt. shape: [batch_size]. If not provided, will be set to the same value as seq_lens.
max_token_per_sequence (Optional[int]) – Required for the cudnn backend. The scalar maximum token length over all sequences.
max_sequence_kv (Optional[int]) – Required for the cudnn backend. The scalar maximum sequence length over all sequences in the kv cache.
v_indptr (Optional[torch.Tensor]) – Required for cudnn backend. This is the indptr of the value tensor.
o_indptr (Optional[torch.Tensor]) – Required for cudnn backend. This is the indptr of the output tensor.
Note
The plan() method should be called before any run() or run_return_lse() calls; auxiliary data structures are created during this plan call and cached for multiple kernel runs.
num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.
The plan() method cannot be used in CUDA Graph or in torch.compile.
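The token_pos_in_items_ptr layout described in the parameter list can be reproduced with a few lines of pure Python. This is a sketch that follows the 3, 2, 4 example from the text; build_token_pos_in_items is a hypothetical helper, and the result would still need to be converted to a tensor before being passed to plan().

```python
def build_token_pos_in_items(item_lens, padded_len):
    """Build the per-prompt position vector and zero-pad it to `padded_len`.

    Each item contributes a 0 delimiter followed by positions 1..len(item);
    one more 0 delimiter closes the prompt.
    """
    pos = []
    for n in item_lens:
        pos.append(0)                 # delimiter that opens the item
        pos.extend(range(1, n + 1))   # token positions inside the item
    pos.append(0)                     # delimiter that closes the prompt
    if len(pos) > padded_len:
        raise ValueError("padded_len is shorter than the raw vector")
    return pos + [0] * (padded_len - len(pos))


raw = build_token_pos_in_items([3, 2, 4], padded_len=13)
print(raw)  # [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
padded = build_token_pos_in_items([3, 2, 4], padded_len=20)
print(padded[-8:])  # [0, 0, 0, 0, 0, 0, 0, 0]: one delimiter plus 7 padding zeros
```

For this prompt the matching max_item_len_ptr entry would be 4, the length of the longest item.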
- reset_workspace_buffer(float_workspace_buffer: Tensor, int_workspace_buffer) None¶
Reset the workspace buffer.
- Parameters:
float_workspace_buffer (torch.Tensor) – The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors.
int_workspace_buffer (torch.Tensor) – The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors.
- run(q: Tensor, k: Tensor, v: Tensor, *args, out: Tensor | None = None, lse: Tensor | None = None, return_lse: Literal[False] = False, enable_pdl: bool | None = None, kv_cache_sf: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None) Tensor¶
- run(q: Tensor, k: Tensor, v: Tensor, *args, out: Tensor | None = None, lse: Tensor | None = None, return_lse: Literal[True] = True, enable_pdl: bool | None = None, kv_cache_sf: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None) Tuple[Tensor, Tensor]
Compute batch prefill/append attention between query and kv-cache stored as ragged tensor.
- Parameters:
q (torch.Tensor) – The query tensor, shape:
[qo_indptr[-1], num_qo_heads, head_dim_qk]k (torch.Tensor) – The key tensor, shape:
[kv_indptr[-1], num_kv_heads, head_dim_qk]v (torch.Tensor) – The value tensor, shape:
[kv_indptr[-1], num_kv_heads, head_dim_vo]*args – Additional arguments for the custom kernel.
q_scale (Optional[float]) – The calibration scale of fp8 query, if not provided, will be set to
1.0.k_scale (Optional[float]) – The calibration scale of fp8 or nvfp4 key, if not provided, will be set to
1.0.v_scale (Optional[float]) – The calibration scale of fp8 or nvfp4 value, if not provided, will be set to
1.0.o_scale (Optional[float]) – The calibration scale of output, if not provided, will be set to
1.0.out (Optional[torch.Tensor]) – The output tensor, if not provided, will be allocated internally.
lse (Optional[torch.Tensor]) – The log-sum-exp of attention logits, if not provided, will be allocated internally.
return_lse (bool) – Whether to return the logsumexp of attention output
enable_pdl (bool) – Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Only supported for >= sm90, and currently only for FA2 and CUDA core decode.
kv_cache_sf (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]) – Per-block scale factors for NVFP4 KV input. Accepts either a single packed scale tensor or a (k_scales, v_scales) tuple matching the structure expected by the chosen backend. When None (default), the kernel runs without NVFP4 KV scaling. See flashinfer.fp4_quantization.nvfp4_quantize() for layout details.
- Returns:
If
return_lseisFalse, the attention output, shape:[qo_indptr[-1], num_qo_heads, head_dim_vo]. Ifreturn_lseisTrue, a tuple of two tensors:The attention output, shape:
[qo_indptr[-1], num_qo_heads, head_dim_vo].The logsumexp of attention output, shape:
[qo_indptr[-1], num_qo_heads].
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
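Because q, k, v and the output are ragged along their first dimension, the rows that belong to request i are those in the half-open range [indptr[i], indptr[i + 1]). A small pure-Python sketch of that indexing, using the qo_indptr values from the class example; request_slices is a hypothetical helper.

```python
def request_slices(indptr):
    """Return one (start, end) row range per request from a CSR-style indptr."""
    return [(indptr[i], indptr[i + 1]) for i in range(len(indptr) - 1)]


qo_indptr = [0, 33, 44, 55, 66, 77, 88, 100]
slices = request_slices(qo_indptr)
lengths = [end - start for start, end in slices]
print(lengths)  # [33, 11, 11, 11, 11, 11, 12]
```

With tensors, the output rows of request i could then be selected with ordinary slicing such as out[start:end].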
Unified BatchAttention¶
The BatchAttention class provides a holistic attention wrapper that automatically dispatches
between paged-prefill and paged-decode based on per-request sequence lengths. It is the
recommended entry point for serving stacks that batch mixed prefill/decode requests in a
single kernel launch.
- class flashinfer.attention.BatchAttention(kv_layout: str = 'NHD', device: str = 'cuda')¶
Holistic batched attention wrapper that fuses paged-prefill and paged-decode requests into a single kernel launch.
BatchAttention dispatches between prefill-style and decode-style execution per request based on the qo_indptr / kv_indptr ranges supplied to plan(), so a serving stack can submit a mixed batch (e.g. some prompts in prefill, others in decode) without splitting it into two separate wrappers. Workspace buffers are owned by the instance and reused across plan() / run() calls.
- Parameters:
kv_layout (str) – Layout of the paged KV-cache tensors, either "NHD" (token-major) or "HND" (head-major). Defaults to "NHD".
device (str) – CUDA device that owns the internal workspace buffers, e.g. "cuda" or "cuda:0". Defaults to "cuda".
- __init__(kv_layout: str = 'NHD', device: str = 'cuda')¶
Allocate workspace buffers and bind the wrapper to a CUDA device.
See
BatchAttentionfor the meaning of each parameter.
- plan(qo_indptr: Tensor, kv_indptr: Tensor, kv_indices: Tensor, kv_len_arr: Tensor, num_qo_heads: int, num_kv_heads: int, head_dim_qk: int, head_dim_vo: int, page_size: int, causal: bool = False, sm_scale: float = None, logits_soft_cap: float | None = None, q_data_type: dtype = torch.bfloat16, kv_data_type: dtype = torch.bfloat16, use_profiler: bool = False) None¶
Plan the holistic attention kernel for a specific batch shape.
Should be called before any run() call. The plan is cached on the instance and reused across subsequent run() invocations with the same layout.
- Parameters:
qo_indptr (torch.Tensor) – CSR-style query offsets, shape [batch_size + 1], dtype int32.
kv_indptr (torch.Tensor) – CSR-style page offsets into kv_indices, shape [batch_size + 1], dtype int32.
kv_indices (torch.Tensor) – Page indices into the paged KV-cache, shape [kv_indptr[-1]], dtype int32.
kv_len_arr (torch.Tensor) – Per-request KV-cache lengths in tokens, shape [batch_size], dtype int32.
num_qo_heads (int) – Number of query / output heads.
num_kv_heads (int) – Number of key / value heads. Must divide num_qo_heads.
head_dim_qk (int) – Per-head dimension of the query / key tensors.
head_dim_vo (int) – Per-head dimension of the value / output tensors.
page_size (int) – Page size of the paged KV-cache.
causal (bool) – Whether to apply a causal mask. Defaults to False.
sm_scale (float) – Softmax scale. If None, defaults to 1/sqrt(head_dim_qk).
logits_soft_cap (Optional[float]) – Logits soft-cap value. None or 0 disables capping.
q_data_type (torch.dtype) – Dtype of the query tensor. Defaults to torch.bfloat16.
kv_data_type (torch.dtype) – Dtype of the key / value tensors. Defaults to torch.bfloat16.
use_profiler (bool) – Whether to compile the profiler-enabled variant of the kernel. Defaults to False.
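The CSR-style index arrays that plan() expects can be derived from per-request query and KV lengths. The sketch below is plain Python and only illustrates the bookkeeping: the helper name build_plan_metadata and the assumption that pages are allocated contiguously in request order are illustrative and not part of FlashInfer. In practice each list would be converted to an int32 tensor on the wrapper's device before being passed to plan().

```python
from itertools import accumulate


def build_plan_metadata(qo_lens, kv_lens, page_size):
    """Hypothetical helper: derive plan()-style index arrays from lengths.

    Assumes pages are handed out contiguously in request order; a real
    allocator would supply its own page ids for kv_indices.
    """
    # Number of pages each request occupies (ceiling division).
    pages_per_req = [-(-n // page_size) for n in kv_lens]
    qo_indptr = [0, *accumulate(qo_lens)]        # length batch_size + 1
    kv_indptr = [0, *accumulate(pages_per_req)]  # length batch_size + 1
    kv_indices = list(range(kv_indptr[-1]))      # one entry per page
    kv_len_arr = list(kv_lens)                   # KV length in tokens
    return qo_indptr, kv_indptr, kv_indices, kv_len_arr


# Mixed batch: request 0 is a 5-token prefill, requests 1 and 2 decode one token.
qo_indptr, kv_indptr, kv_indices, kv_len_arr = build_plan_metadata(
    qo_lens=[5, 1, 1], kv_lens=[5, 33, 16], page_size=16
)
print(qo_indptr)   # [0, 5, 6, 7]
print(kv_indptr)   # [0, 1, 4, 5]
print(kv_indices)  # [0, 1, 2, 3, 4]
```

Keeping kv_len_arr in tokens while kv_indptr counts pages is what lets a partially filled last page be described without a separate last-page-length array.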
- run(q: Tensor, kv_cache: Tensor | Tuple[Tensor, Tensor], out: Tensor | None = None, lse: Tensor | None = None, k_scale: Tensor | None = None, v_scale: Tensor | None = None, logits_soft_cap: float = 0.0, profiler_buffer: Tensor | None = None, kv_cache_sf: Tensor | Tuple[Tensor, Tensor] | None = None) Tuple[Tensor, Tensor]¶
Run the planned holistic attention kernel.
- Parameters:
q (torch.Tensor) – Query tensor, shape [total_qo_tokens, num_qo_heads, head_dim_qk].
kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – Either a single packed paged KV-cache tensor (when K and V share storage) or a (k_cache, v_cache) pair. Layout must match the kv_layout passed to __init__().
out (Optional[torch.Tensor]) – Optional output buffer. If None, a new tensor is allocated with the same shape as q.
lse (Optional[torch.Tensor]) – Optional log-sum-exp buffer, shape [total_qo_tokens, num_qo_heads], dtype float32. Allocated if None.
k_scale (Optional[torch.Tensor]) – FP8 dequantization scale for k. Pre-multiplied into sm_scale.
v_scale (Optional[torch.Tensor]) – FP8 dequantization scale for v. Applied to the output.
logits_soft_cap (float) – Logits soft-cap value. Must be consistent with the logits_soft_cap passed to plan() (a non-zero value here requires a non-zero plan-time value too).
profiler_buffer (Optional[torch.Tensor]) – Profiler buffer. Required if the wrapper was planned with use_profiler=True.
kv_cache_sf (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]) – Optional scale tensors for NVFP4 KV-cache (one tensor or a (k_sf, v_sf) pair, mirroring the structure of kv_cache).
- Returns:
(out, lse): the attention output and its log-sum-exp.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
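The scale handling described for k_scale, v_scale and logits_soft_cap can be checked against a plain-Python reference. The sketch below is a single-query, single-head attention written for illustration only: the function attend, the toy values, and the soft-cap form cap * tanh(x / cap) applied to the already scaled logits are assumptions, not FlashInfer's kernel code. It shows that folding k_scale into sm_scale and multiplying the output by v_scale gives the same result as dequantizing K and V up front.

```python
import math


def attend(q, keys, values, sm_scale, soft_cap=0.0):
    """Reference single-query attention (illustrative, not FlashInfer code)."""
    logits = [sm_scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    if soft_cap > 0.0:
        # Assumed soft-cap form: squash scaled logits into (-cap, cap).
        logits = [soft_cap * math.tanh(x / soft_cap) for x in logits]
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]  # numerically stable softmax
    denom = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom
            for d in range(dim)]


q = [0.3, -0.7, 1.1, 0.2]
k_quant = [[4.0, -2.0, 1.0, 0.0], [1.0, 3.0, -1.0, 2.0], [0.0, 1.0, 2.0, -3.0]]
v_quant = [[2.0, 0.0, -1.0, 1.0], [1.0, 1.0, 0.0, -2.0], [-3.0, 2.0, 1.0, 0.0]]
k_scale, v_scale = 0.25, 0.5           # toy dequantization scales
sm_scale = 1.0 / math.sqrt(len(q))     # default 1/sqrt(head_dim_qk)

# Path A: dequantize K and V first, then attend.
k_deq = [[x * k_scale for x in row] for row in k_quant]
v_deq = [[x * v_scale for x in row] for row in v_quant]
out_a = attend(q, k_deq, v_deq, sm_scale, soft_cap=30.0)

# Path B: fold k_scale into sm_scale, apply v_scale to the output.
out_b = [x * v_scale
         for x in attend(q, k_quant, v_quant, sm_scale * k_scale, soft_cap=30.0)]

max_diff = max(abs(a - b) for a, b in zip(out_a, out_b))
print(max_diff < 1e-9)
```

Path B avoids materializing dequantized K and V, which is why per-tensor scales are usually folded into the softmax scale and the output instead.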
- class flashinfer.attention.BatchAttentionWithAttentionSinkWrapper(float_workspace_buffer: Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: Tensor | None = None, paged_kv_indptr_buf: Tensor | None = None, paged_kv_indices_buf: Tensor | None = None, paged_kv_last_page_len_buf: Tensor | None = None, custom_mask_buf: Tensor | None = None, mask_indptr_buf: Tensor | None = None, backend: str = 'auto', pos_encoding_mode: str = 'NONE', use_fp16_qk_reduction: bool = False, q_data_type: dtype = torch.bfloat16, kv_data_type: dtype = torch.bfloat16, head_dim_qk: int = 128, head_dim_vo: int = 128, window_left: int = -1)¶
Wrapper for prefill and decode attention with paged KV-cache that adds support for attention sinks. This class extends BatchPrefillWithPagedKVCacheWrapper, providing a convenient interface for using attention sinks during prefill or decode attention.
- __init__(float_workspace_buffer: Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: Tensor | None = None, paged_kv_indptr_buf: Tensor | None = None, paged_kv_indices_buf: Tensor | None = None, paged_kv_last_page_len_buf: Tensor | None = None, custom_mask_buf: Tensor | None = None, mask_indptr_buf: Tensor | None = None, backend: str = 'auto', pos_encoding_mode: str = 'NONE', use_fp16_qk_reduction: bool = False, q_data_type: dtype = torch.bfloat16, kv_data_type: dtype = torch.bfloat16, head_dim_qk: int = 128, head_dim_vo: int = 128, window_left: int = -1) None¶
Constructor of BatchPrefillWithPagedKVCacheWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user reserved workspace buffer used to store intermediate attention results in split-k algorithm. The recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors.
kv_layout (str) – The layout of the input k/v tensors, could be either NHD or HND.
use_cuda_graph (bool) – Whether to enable CUDA graph capture for the prefill kernels, if enabled, the auxiliary data structures will be stored in provided buffers. The batch_size cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.
qo_indptr_buf (Optional[torch.Tensor]) – The user reserved buffer to store the qo_indptr array, the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.
paged_kv_indptr_buf (Optional[torch.Tensor]) – The user reserved buffer to store the paged_kv_indptr array, the size of this buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.
paged_kv_indices_buf (Optional[torch.Tensor]) – The user reserved buffer to store the paged_kv_indices array, should be large enough to store the maximum possible size of the paged_kv_indices array during the lifetime of the wrapper. This argument is only effective when use_cuda_graph is True.
paged_kv_last_page_len_buf (Optional[torch.Tensor]) – The user reserved buffer to store the paged_kv_last_page_len array, the size of the buffer should be [batch_size]. This argument is only effective when use_cuda_graph is True.
custom_mask_buf (Optional[torch.Tensor]) – The user reserved buffer to store the custom mask tensor, should be large enough to store the maximum possible size of the packed custom mask tensor during the lifetime of the wrapper. This argument is only effective when use_cuda_graph is set to True and the custom mask will be used in attention computation.
mask_indptr_buf (Optional[torch.Tensor]) – The user reserved buffer to store the mask_indptr array, the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True and the custom mask will be used in attention computation.
backend (str) – The implementation backend, could be auto / fa2 / fa3 / cudnn or trtllm-gen. Defaults to auto. If set to auto, the wrapper will automatically choose the backend based on the device architecture and kernel availability.
jit_args (Optional[List[Any]]) – If provided, the wrapper will use the provided arguments to create the JIT module, otherwise, the wrapper will use default attention implementation.
jit_kwargs (Optional[Dict[str, Any]]) – The keyword arguments to create the JIT module, defaults to None.
SM120 NVFP4 Attention¶
|
Preprocess and quantize dense Q/K/V tensors for SM120 NVFP4 attention. |
|
Run SM120 NVFP4 attention on pre-quantized Q/K/V tensors. |
flashinfer.mla¶
MLA (Multi-head Latent Attention) is an attention mechanism proposed in the DeepSeek series of models (DeepSeek-V2, DeepSeek-V3, and DeepSeek-R1).
PageAttention for MLA¶
|
Decode MLA with TRTLLM-GEN, CuteDSL, XQA, or SM120/SM121 sparse kernels. |
|
XQA-backend batched MLA decode. |
- class flashinfer.mla.BatchMLAPagedAttentionWrapper(float_workspace_buffer: Tensor, use_cuda_graph: bool = False, qo_indptr: Tensor | None = None, kv_indptr: Tensor | None = None, kv_indices: Tensor | None = None, kv_len_arr: Tensor | None = None, backend: str = 'auto')¶
Wrapper class for MLA (Multi-head Latent Attention) PagedAttention on DeepSeek models. This kernel can be used in decode and incremental prefill, and should be used together with the Matrix Absorption trick, where \(W_{UQ}\) is absorbed with \(W_{UK}\), and \(W_{UV}\) is absorbed with \(W_{O}\). For MLA attention without Matrix Absorption (head_dim_qk=192 and head_dim_vo=128, which is used in the prefilling self-attention stage), please use flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper.
More information about the Paged KV-Cache layout in MLA is available in our tutorial MLA Page Layout.
For more details about the MLA computation, Matrix Absorption, and FlashInfer’s MLA implementation, please refer to our blog post.
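As a rough sketch of why Matrix Absorption works, consider a single head and ignore the rope part. The symbols \(c_Q\) (compressed query latent), \(c_{KV,j}\) (compressed KV latent of token \(j\)) and \(p_j\) (softmax weight of token \(j\)) are notation introduced here for illustration, with all vectors treated as columns; the blog post remains the authoritative derivation.

```latex
% Score: the two up-projections collapse into one matrix applied to the query side.
q^{\top} k_j
  = (W_{UQ}\, c_Q)^{\top} (W_{UK}\, c_{KV,j})
  = c_Q^{\top} \left( W_{UQ}^{\top} W_{UK} \right) c_{KV,j}

% Output: the value up-projection collapses into the output projection.
o
  = W_{O} \sum_j p_j \left( W_{UV}\, c_{KV,j} \right)
  = \left( W_{O} W_{UV} \right) \sum_j p_j\, c_{KV,j}
```

Under these assumptions the kernel only ever needs to attend over the compressed latents, which is why the paged cache stores ckv rather than full per-head keys and values.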
Example
>>> import torch
>>> import flashinfer
>>> num_local_heads = 128
>>> batch_size = 114
>>> head_dim_ckv = 512
>>> head_dim_kpe = 64
>>> page_size = 1
>>> mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
...     torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0),
...     backend="fa2"
... )
>>> q_indptr = torch.arange(0, batch_size + 1).to(0).int()  # for decode, each query length is 1
>>> kv_lens = torch.full((batch_size,), 999, dtype=torch.int32).to(0)
>>> kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * 999
>>> kv_indices = torch.arange(0, batch_size * 999).to(0).int()
>>> q_nope = torch.randn(
...     batch_size * 1, num_local_heads, head_dim_ckv, dtype=torch.bfloat16, device="cuda"
... )
>>> q_pe = torch.zeros(
...     batch_size * 1, num_local_heads, head_dim_kpe, dtype=torch.bfloat16, device="cuda"
... )
>>> ckv = torch.randn(
...     batch_size * 999, 1, head_dim_ckv, dtype=torch.bfloat16, device="cuda"
... )
>>> kpe = torch.zeros(
...     batch_size * 999, 1, head_dim_kpe, dtype=torch.bfloat16, device="cuda"
... )
>>> sm_scale = 1.0 / ((128 + 64) ** 0.5)  # use head dimension before matrix absorption
>>> mla_wrapper.plan(
...     q_indptr,
...     kv_indptr,
...     kv_indices,
...     kv_lens,
...     num_local_heads,
...     head_dim_ckv,
...     head_dim_kpe,
...     page_size,
...     False,  # causal
...     sm_scale,
...     q_nope.dtype,
...     ckv.dtype,
... )
>>> o = mla_wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False)
>>> o.shape
torch.Size([114, 128, 512])
- __init__(float_workspace_buffer: Tensor, use_cuda_graph: bool = False, qo_indptr: Tensor | None = None, kv_indptr: Tensor | None = None, kv_indices: Tensor | None = None, kv_len_arr: Tensor | None = None, backend: str = 'auto') None¶
Constructor for BatchMLAPagedAttentionWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user reserved workspace buffer used to store intermediate attention results in split-k algorithm. The recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors.
use_cuda_graph (bool, optional) – Whether to enable CUDA graph capture for the prefill kernels, if enabled, the auxiliary data structures will be stored in provided buffers. The batch_size cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.
qo_indptr (Optional[torch.Tensor]) – User-reserved buffer to back the qo_indptr array, shape [batch_size + 1], dtype int32. Only consulted when use_cuda_graph=True. The wrapper copies into this buffer at plan() time so capture-time pointers remain stable.
kv_indptr (Optional[torch.Tensor]) – User-reserved buffer to back the kv_indptr array, shape [batch_size + 1], dtype int32. Only consulted when use_cuda_graph=True.
kv_indices (Optional[torch.Tensor]) – User-reserved buffer to back the kv_indices array, sized to the maximum expected number of pages, dtype int32. Only consulted when use_cuda_graph=True.
kv_len_arr (Optional[torch.Tensor]) – User-reserved buffer to back the kv_len_arr array, shape [batch_size], dtype int32. Only consulted when use_cuda_graph=True.
backend (str) – One of "auto", "fa2", "fa3", "cutlass". Default "auto". "auto" picks "fa3" on SM90a, else "fa2". On SM>=100 neither is Blackwell-native; for MLA decode prefer trtllm_batch_decode_with_kv_cache_mla(). The "cutlass" option is the closest in-wrapper alternative but may be slower than the fa2 fallback for decode shapes. "cutlass" uses the SM100/SM110 CUTLASS MLA decode kernel. With "cutlass", only float_workspace_buffer is required, and run() takes a different input layout (concatenated q_nope_pe / ckv_kpe_cache plus kv_len and page_table).
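The "auto" rule stated above can be written down as a tiny function. This is a hypothetical mirror of the documented behavior, not FlashInfer's actual dispatch code; the name pick_mla_backend is invented for illustration, and SM90a is assumed to correspond to CUDA compute capability 9.0.

```python
def pick_mla_backend(major: int, minor: int) -> str:
    """Hypothetical mirror of the documented "auto" rule (not FlashInfer code)."""
    # Assumption: SM90a devices report compute capability (9, 0).
    if (major, minor) == (9, 0):
        return "fa3"
    # Everything else, including SM>=100, falls back to "fa2".
    return "fa2"


for cc in [(8, 0), (9, 0), (10, 0)]:
    print(cc, pick_mla_backend(*cc))
```

On a live system the capability tuple could come from torch.cuda.get_device_capability(); per the text above, an SM>=100 result is a cue to consider trtllm_batch_decode_with_kv_cache_mla() rather than relying on "auto".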
- plan(qo_indptr: Tensor, kv_indptr: Tensor, kv_indices: Tensor, kv_len_arr: Tensor, num_heads: int, head_dim_ckv: int, head_dim_kpe: int, page_size: int, causal: bool, sm_scale: float, q_data_type: dtype, kv_data_type: dtype, use_profiler: bool = False) None¶
Plan the MLA attention computation.
- Parameters:
qo_indptr (torch.IntTensor) – The indptr of the query/output tensor, shape: [batch_size + 1]. For decoding attention, the length of each query is 1, and the content of the tensor should be [0, 1, 2, ..., batch_size].
kv_indptr (torch.IntTensor) – The indptr of the paged kv-cache, shape: [batch_size + 1].
kv_indices (torch.IntTensor) – The page indices of the paged kv-cache, shape: [kv_indptr[-1]] or larger.
kv_len_arr (torch.IntTensor) – The KV-cache length of each request, shape: [batch_size].
num_heads (int) – The number of heads in query/output tensor.
head_dim_ckv (int) – The head dimension of compressed-kv.
head_dim_kpe (int) – The head dimension for rope k-cache.
page_size (int) – The page size of the paged kv-cache.
causal (bool) – Whether to use causal attention.
sm_scale (float) – The scale factor for softmax operation.
q_data_type (torch.dtype) – The data type of the query tensor.
kv_data_type (torch.dtype) – The data type of the kv-cache tensor.
use_profiler (bool, optional) – Whether to enable intra-kernel profiler, default is False.
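For a pure decode batch the plan() inputs follow a simple pattern. The sketch below uses plain Python lists to show the values only; the small batch size and the contiguous page ids are assumptions for illustration, and the dimensions 128 and 64 are taken from the class example, where sm_scale is computed from the head dimension before matrix absorption. Real calls would pass int32 tensors on the GPU.

```python
import math
from itertools import accumulate

batch_size = 4
page_size = 1
kv_lens = [999] * batch_size             # KV length of each request, in tokens

# Decode: one query token per request, so qo_indptr is 0..batch_size.
qo_indptr = list(range(batch_size + 1))

# Pages per request (ceiling division), then CSR offsets into kv_indices.
pages_per_req = [-(-n // page_size) for n in kv_lens]
kv_indptr = [0, *accumulate(pages_per_req)]
kv_indices = list(range(kv_indptr[-1]))  # assumption: contiguous page ids

# Softmax scale from the head dimensions before absorption (128 + 64).
sm_scale = 1.0 / math.sqrt(128 + 64)

print(qo_indptr)
print(kv_indptr)
```

With page_size = 1 the page count equals the token count, so kv_indptr is simply the running sum of kv_lens.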
- run(q_nope: Tensor, q_pe: Tensor, ckv_cache: Tensor, kpe_cache: Tensor, out: Tensor | None = None, lse: Tensor | None = None, return_lse: Literal[False] = False, profiler_buffer: Tensor | None = None, kv_len: Tensor | None = None, page_table: Tensor | None = None, return_lse_base_on_e: bool = False, o_scale: float | None = None) Tensor¶
- run(q_nope: Tensor, q_pe: Tensor, ckv_cache: Tensor, kpe_cache: Tensor, out: Tensor | None = None, lse: Tensor | None = None, return_lse: Literal[True] = True, profiler_buffer: Tensor | None = None, kv_len: Tensor | None = None, page_table: Tensor | None = None, return_lse_base_on_e: bool = False, o_scale: float | None = None) Tuple[Tensor, Tensor]
Run the MLA attention computation.
- Parameters:
q_nope (torch.Tensor) – The query tensor without rope, shape: [batch_size, num_heads, head_dim_ckv].
q_pe (torch.Tensor) – The rope part of the query tensor, shape: [batch_size, num_heads, head_dim_kpe].
ckv_cache (torch.Tensor) – The compressed kv-cache tensor (without rope), shape: [num_pages, page_size, head_dim_ckv]. head_dim_ckv is 512 in DeepSeek v2/v3 models.
kpe_cache (torch.Tensor) – The rope part of the kv-cache tensor, shape: [num_pages, page_size, head_dim_kpe]. head_dim_kpe is 64 in DeepSeek v2/v3 models.
out (Optional[torch.Tensor]) – The output tensor, if not provided, will be allocated internally. When o_scale is provided, this should be an FP8 tensor.
lse (Optional[torch.Tensor]) – The log-sum-exp of attention logits, if not provided, will be allocated internally.
return_lse (bool, optional) – Whether to return the log-sum-exp value, default is False.
profiler_buffer (Optional[torch.Tensor]) – The buffer to store the profiler data.
kv_len (Optional[torch.Tensor]) – The KV-cache length of each request, shape: [batch_size]. Required when backend is cutlass.
page_table (Optional[torch.Tensor]) – The page table of the paged kv-cache, shape: [batch_size, num_pages]. Required when backend is cutlass.
return_lse_base_on_e (bool, optional) – Controls the base of the returned LSE values when return_lse=True. If False (default), the LSE is returned in base-2 (log2(sum(exp2(...)))) to match the kernel’s internal log-base. If True, the LSE is converted to natural-log base (log(sum(exp(...)))) for compatibility with cascade-merging APIs that expect base-e LSEs.
o_scale (Optional[float]) – FP8 output dequantization scale (real = quantized * o_scale). When provided, out must be an FP8 tensor. Only supported with the cutlass backend.
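The relationship between the two LSE bases selected by return_lse_base_on_e can be checked numerically. The sketch below is plain Python and assumes the base-2 form is computed on logits pre-multiplied by log2(e), which is a common kernel convention but is not confirmed by the text above; the function names are illustrative. Under that assumption the two values differ by a constant factor of ln(2).

```python
import math

LOG2E = math.log2(math.e)


def lse_base2(logits):
    """log2(sum(exp2(x * log2(e)))), computed with max-subtraction for stability."""
    scaled = [x * LOG2E for x in logits]
    m = max(scaled)
    return m + math.log2(sum(2.0 ** (s - m) for s in scaled))


def lse_natural(logits):
    """log(sum(exp(x))), computed with max-subtraction for stability."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))


logits = [0.5, -1.25, 2.0, 0.0]
lse2 = lse_base2(logits)
lse_e = lse_natural(logits)

# Converting base-2 LSE to natural-log LSE is a multiplication by ln(2).
converted = lse2 * math.log(2.0)
print(math.isclose(converted, lse_e, rel_tol=1e-9))
```

Mixing the two conventions when merging partial attention states would weight the states incorrectly, so it is worth converting once at the boundary rather than inside the merge.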