FlashInfer Attention Kernels
flashinfer.decode
Single Request Decoding
Decode attention with KV cache for a single request; returns the attention output.
Batch Decoding
Performs batched decode attention with paged KV cache using cuDNN.
- class flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, use_tensor_cores: bool = False, paged_kv_indptr_buffer: torch.Tensor | None = None, paged_kv_indices_buffer: torch.Tensor | None = None, paged_kv_last_page_len_buffer: torch.Tensor | None = None, backend: str = 'auto', jit_args: List[Any] | None = None)
Wrapper class for decode attention with paged kv-cache (first proposed in vLLM) for a batch of requests.
Check our tutorial for the page table layout.
Examples
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 8
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> kv_page_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= kv_last_page_len <= page_size
>>> kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> kv_cache_at_layer = [
...     torch.randn(
...         max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> # create auxiliary data structures for batch decode attention
>>> decode_wrapper.plan(
...     kv_page_indptr,
...     kv_page_indices,
...     kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     pos_encoding_mode="NONE",
...     data_type=torch.float16
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
...     kv_cache = kv_cache_at_layer[i]
...     # compute batch decode attention, reuse auxiliary data structures for all layers
...     o = decode_wrapper.run(q, kv_cache)
...     outputs.append(o)
...
>>> outputs[0].shape
torch.Size([7, 64, 128])
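The page table in this example also determines each request's KV length: request i owns kv_page_indptr[i + 1] - kv_page_indptr[i] pages, all full except the last, which holds kv_last_page_len[i] entries. The pure-Python sketch below (no GPU or FlashInfer required; the helper name kv_lengths_from_page_table is hypothetical, not part of the library) recomputes those lengths for the values above, using the same formula as the prefill custom-mask example later on this page.

```python
from typing import List


def kv_lengths_from_page_table(
    indptr: List[int], last_page_len: List[int], page_size: int
) -> List[int]:
    """Hypothetical helper: KV length of each request in a paged KV cache.

    Request i owns pages indices[indptr[i]:indptr[i + 1]]; every page except
    the last is full, and the last page holds last_page_len[i] entries.
    """
    lengths = []
    for i, last in enumerate(last_page_len):
        num_pages = indptr[i + 1] - indptr[i]
        # each request owns at least one page, and 1 <= last_page_len <= page_size
        assert num_pages >= 1 and 1 <= last <= page_size
        lengths.append((num_pages - 1) * page_size + last)
    return lengths


# Values copied from the decode example above.
kv_page_indptr = [0, 17, 29, 44, 48, 66, 100, 128]
kv_last_page_len = [1, 7, 14, 4, 3, 1, 16]
print(kv_lengths_from_page_table(kv_page_indptr, kv_last_page_len, page_size=16))
# [257, 183, 238, 52, 275, 529, 448]
```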
Note
To accelerate computation, FlashInfer’s batch decode attention creates auxiliary data structures that can be reused across multiple batch decode attention calls (e.g., different Transformer layers). This wrapper class manages the lifecycle of these data structures.
- __init__(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, use_tensor_cores: bool = False, paged_kv_indptr_buffer: torch.Tensor | None = None, paged_kv_indices_buffer: torch.Tensor | None = None, paged_kv_last_page_len_buffer: torch.Tensor | None = None, backend: str = 'auto', jit_args: List[Any] | None = None) -> None
Constructor of BatchDecodeWithPagedKVCacheWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user-reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB; the workspace buffer should be on the same device as the input tensors.
kv_layout (str) – The layout of the input k/v tensors, either NHD or HND.
use_cuda_graph (bool) – Whether to enable CUDAGraph for batch decode attention. If enabled, the auxiliary data structures are stored in the provided buffers. The batch_size cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.
use_tensor_cores (bool) – Whether to use tensor cores for the computation. This is faster for large group sizes in grouped query attention. Defaults to False.
paged_kv_indptr_buffer (Optional[torch.Tensor]) – The user-reserved buffer on GPU to store the indptr of the paged kv cache; its size should be [batch_size + 1]. Only needed when use_cuda_graph is True.
paged_kv_indices_buffer (Optional[torch.Tensor]) – The user-reserved buffer on GPU to store the page indices of the paged kv cache. It should be large enough to store the maximum number of page indices (max_num_pages) during the lifecycle of this wrapper. Only needed when use_cuda_graph is True.
paged_kv_last_page_len_buffer (Optional[torch.Tensor]) – The user-reserved buffer on GPU to store the number of entries in the last page; its size should be [batch_size]. Only needed when use_cuda_graph is True.
backend (str) – The implementation backend, either auto, fa2 or trtllm-gen. Defaults to auto. If set to auto, the wrapper automatically chooses the backend based on the device architecture and kernel availability.
jit_args (Optional[List[Any]]) – If provided, the wrapper uses these arguments to create the JIT module; otherwise, it uses the default attention implementation.
- plan(indptr: torch.Tensor, indices: torch.Tensor, last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, pos_encoding_mode: str = 'NONE', window_left: int = -1, logits_soft_cap: float | None = None, q_data_type: str | torch.dtype | None = 'float16', kv_data_type: str | torch.dtype | None = None, data_type: str | torch.dtype | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, non_blocking: bool = True, block_tables: torch.Tensor | None = None, seq_lens: torch.Tensor | None = None) -> None
Plan batch decode for given problem specification.
- Parameters:
indptr (torch.Tensor) – The indptr of the paged kv cache, shape: [batch_size + 1]
indices (torch.Tensor) – The page indices of the paged kv cache, shape: [indptr[-1]]
last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv cache, shape: [batch_size]
num_qo_heads (int) – The number of query/output heads
num_kv_heads (int) – The number of key/value heads
head_dim (int) – The dimension of the heads
page_size (int) – The page size of the paged kv cache
pos_encoding_mode (str) – The position encoding applied inside attention kernels, either NONE, ROPE_LLAMA (LLAMA-style rotary embedding) or ALIBI. Defaults to NONE.
window_left (int) – The left (inclusive) window size for the attention window. When set to -1, the window covers the full length of the sequence. Defaults to -1.
logits_soft_cap (Optional[float]) – The attention logits soft-capping value (used in Gemini, Grok, Gemma-2, etc.). If not provided, it is set to 0. If greater than 0, the logits are capped according to the formula \(\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})\), where \(x\) is the input logits.
q_data_type (Optional[Union[str, torch.dtype]]) – The data type of the query tensor. Defaults to torch.float16.
kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, it is set to q_data_type. Defaults to None.
data_type (Optional[Union[str, torch.dtype]]) – The data type of both the query and key/value tensors. Defaults to torch.float16. data_type is deprecated; use q_data_type and kv_data_type instead.
non_blocking (bool) – Whether to copy the input tensors to the device asynchronously. Defaults to True.
seq_lens (Optional[torch.Tensor]) – A uint32 1D tensor indicating the kv sequence length of each prompt, shape: [batch_size].
block_tables (Optional[torch.Tensor]) – A uint32 2D tensor indicating the block table of each prompt, shape: [batch_size, max_num_blocks_per_seq].
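As a rough illustration of two of these options, the sketch below applies the soft-capping formula to a single logit and lists which KV positions one decode query can see for a given window_left. It assumes the decode query sits at the last position of the sequence (kv_len - 1), so a window of w admits the w + 1 most recent keys; this reading of window_left is an assumption, and both helper names are hypothetical.

```python
import math


def soft_cap_logit(x: float, logits_soft_cap: float) -> float:
    """Apply logits_soft_cap * tanh(x / logits_soft_cap); a cap of 0 disables it."""
    if logits_soft_cap <= 0:
        return x
    return logits_soft_cap * math.tanh(x / logits_soft_cap)


def decode_visible_kv_range(kv_len: int, window_left: int) -> range:
    """KV positions a single decode query may attend to (assumed semantics).

    The decode query is assumed to sit at position kv_len - 1. With
    window_left = w >= 0 it sees positions kv_len - 1 - w .. kv_len - 1
    (the left edge is inclusive); window_left = -1 means the whole sequence.
    """
    if window_left < 0:
        return range(0, kv_len)
    return range(max(0, kv_len - 1 - window_left), kv_len)


print(round(soft_cap_logit(100.0, 30.0), 4))  # close to, but below, 30.0
print(list(decode_visible_kv_range(10, 3)))   # [6, 7, 8, 9]
print(len(decode_visible_kv_range(10, -1)))   # 10
```

Because tanh saturates at 1, any logit is squashed into the open interval (-logits_soft_cap, logits_soft_cap), while small logits pass through almost unchanged.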
Note
The plan() method should be called before any run() or run_return_lse() calls; auxiliary data structures are created during this call and cached for multiple run calls.
The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function uses grouped query attention.
The plan() method cannot be used in CUDA Graph or in torch.compile.
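To make the grouped query attention relationship concrete, here is a small sketch of the head mapping. It assumes the conventional layout in which each run of consecutive query heads shares one KV head, with group size num_qo_heads // num_kv_heads (the group size that the use_tensor_cores option refers to). The function name is hypothetical.

```python
from typing import List


def gqa_kv_head_for_qo_heads(num_qo_heads: int, num_kv_heads: int) -> List[int]:
    """KV head used by each query/output head under grouped query attention.

    Assumes consecutive query heads form groups of
    group_size = num_qo_heads // num_kv_heads that share one KV head.
    """
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    group_size = num_qo_heads // num_kv_heads
    return [qo_head // group_size for qo_head in range(num_qo_heads)]


# Decode example above: 64 query heads, 8 KV heads -> group size 8.
mapping = gqa_kv_head_for_qo_heads(64, 8)
print(mapping[:9])  # [0, 0, 0, 0, 0, 0, 0, 0, 1]
```

When num_qo_heads == num_kv_heads the group size is 1 and the mapping is the identity, i.e., ordinary multi-head attention.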
- reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor) -> None
Reset the workspace buffer.
- Parameters:
float_workspace_buffer (torch.Tensor) – The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors.
int_workspace_buffer (torch.Tensor) – The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors.
- run(q: torch.Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], *args, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, out: torch.Tensor | None = None, lse: torch.Tensor | None = None, return_lse: Literal[False] = False, enable_pdl: bool | None = None, window_left: int | None = None) -> torch.Tensor
- run(q: torch.Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], *args, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, out: torch.Tensor | None = None, lse: torch.Tensor | None = None, return_lse: Literal[True] = True, enable_pdl: bool | None = None, window_left: int | None = None) -> Tuple[torch.Tensor, torch.Tensor]
Compute batch decode attention between query and paged kv cache.
- Parameters:
q (torch.Tensor) – The query tensor, shape: [batch_size, num_qo_heads, head_dim]
paged_kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) –
The paged KV-Cache stored as a tuple of tensors or a single tensor:
  - a tuple (k_cache, v_cache) of 4-D tensors, each with shape [max_num_pages, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, num_kv_heads, page_size, head_dim] if kv_layout is HND.
  - a single 5-D tensor with shape [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND, where paged_kv_cache[:, 0] is the key-cache and paged_kv_cache[:, 1] is the value-cache.
*args – Additional arguments for the custom kernel.
q_scale (Optional[float]) – The calibration scale of the query for fp8 input. If not provided, it is set to 1.0.
k_scale (Optional[float]) – The calibration scale of the key for fp8 input. If not provided, it is set to 1.0.
v_scale (Optional[float]) – The calibration scale of the value for fp8 input. If not provided, it is set to 1.0.
out (Optional[torch.Tensor]) – The output tensor. If not provided, it is allocated internally.
lse (Optional[torch.Tensor]) – The log-sum-exp of attention logits. If not provided, it is allocated internally.
return_lse (bool) – Whether to return the log-sum-exp of attention scores. Defaults to False.
enable_pdl (bool) – Whether to enable Programmatic Dependent Launch (PDL); see https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization for details. Only supported on sm90 and newer, and currently only for FA2 and CUDA-core decode.
- Returns:
If return_lse is False, the attention output, shape: [batch_size, num_qo_heads, head_dim]. If return_lse is True, a tuple of two tensors:
  - the attention output, shape: [batch_size, num_qo_heads, head_dim];
  - the logsumexp of attention scores, shape: [batch_size, num_qo_heads].
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
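The sketch below shows, in plain Python, how one token of a request can be located in the paged KV cache and then indexed in the single 5-D tensor for each kv_layout. It assumes a request's tokens fill its pages in order, so token pos of request i lives in page indices[indptr[i] + pos // page_size] at slot pos % page_size; both helper names are hypothetical.

```python
from typing import List, Tuple


def locate_token(
    indptr: List[int], indices: List[int], page_size: int, request: int, pos: int
) -> Tuple[int, int]:
    """Return (physical page id, slot within page) for token `pos` of `request`."""
    page_ids = indices[indptr[request]:indptr[request + 1]]
    return page_ids[pos // page_size], pos % page_size


def kv_cache_index(
    kv_layout: str, page: int, kv: int, slot: int, head: int, dim: int
) -> Tuple[int, int, int, int, int]:
    """Index into the single 5-D paged_kv_cache tensor (kv = 0 for K, 1 for V)."""
    if kv_layout == "NHD":  # [max_num_pages, 2, page_size, num_kv_heads, head_dim]
        return (page, kv, slot, head, dim)
    if kv_layout == "HND":  # [max_num_pages, 2, num_kv_heads, page_size, head_dim]
        return (page, kv, head, slot, dim)
    raise ValueError(f"unknown kv_layout: {kv_layout}")


# Hypothetical page table: request 0 owns physical pages 5 and 3 (page_size = 4).
page, slot = locate_token([0, 2], [5, 3], page_size=4, request=0, pos=5)
print(page, slot)                                  # 3 1
print(kv_cache_index("HND", page, 1, slot, 2, 0))  # (3, 1, 2, 1, 0)
```

The only difference between the two layouts is whether the slot-within-page axis comes before (NHD) or after (HND) the KV-head axis.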
- class flashinfer.decode.CUDAGraphBatchDecodeWithPagedKVCacheWrapper(workspace_buffer: torch.Tensor, indptr_buffer: torch.Tensor, indices_buffer: torch.Tensor, last_page_len_buffer: torch.Tensor, kv_layout: str = 'NHD', use_tensor_cores: bool = False)
CUDAGraph-compatible wrapper class for decode attention with paged kv-cache (first proposed in vLLM) for a batch of requests.
Note that this wrapper may not be as efficient as BatchDecodeWithPagedKVCacheWrapper because, to accommodate the CUDAGraph requirement, it does not dispatch to different kernels for different batch sizes, sequence lengths, etc. Check our tutorial for the page table layout.
Note
The plan() method cannot be captured by CUDAGraph.
- __init__(workspace_buffer: torch.Tensor, indptr_buffer: torch.Tensor, indices_buffer: torch.Tensor, last_page_len_buffer: torch.Tensor, kv_layout: str = 'NHD', use_tensor_cores: bool = False) -> None
Constructor of CUDAGraphBatchDecodeWithPagedKVCacheWrapper.
- Parameters:
workspace_buffer (torch.Tensor) – The user-reserved workspace buffer on GPU used to store auxiliary data structures. The recommended size is 128MB; the workspace buffer should be on the same device as the input tensors.
indptr_buffer (torch.Tensor) – The user-reserved buffer on GPU to store the indptr of the paged kv cache. It should be large enough to store the indptr of the maximum batch size ([max_batch_size + 1]) during the lifecycle of this wrapper.
indices_buffer (torch.Tensor) – The user-reserved buffer on GPU to store the page indices of the paged kv cache. It should be large enough to store the maximum number of page indices (max_num_pages) during the lifecycle of this wrapper.
last_page_len_buffer (torch.Tensor) – The user-reserved buffer on GPU to store the number of entries in the last page. It should be large enough to store the maximum batch size ([max_batch_size]) during the lifecycle of this wrapper.
use_tensor_cores (bool) – Whether to use tensor cores for the computation. This is faster for large group sizes in grouped query attention. Defaults to False.
kv_layout (str) – The layout of the input k/v tensors, either NHD or HND.
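Because these buffers are reserved once and reused for every call, each batch passed to plan() has to fit in them. A hypothetical pre-flight check, written in plain Python and based only on the buffer sizes listed above (max_batch_size + 1 indptr entries, max_num_pages page indices, max_batch_size last-page lengths), might look like this:

```python
from typing import List


def fits_cuda_graph_buffers(
    indptr: List[int],
    last_page_len: List[int],
    max_batch_size: int,
    max_num_pages: int,
) -> bool:
    """Hypothetical check that one batch's page table fits the reserved buffers."""
    batch_size = len(indptr) - 1
    return (
        batch_size == len(last_page_len)   # one last-page length per request
        and batch_size <= max_batch_size   # indptr / last_page_len buffer capacity
        and indptr[-1] <= max_num_pages    # total pages referenced by the batch
    )


print(fits_cuda_graph_buffers([0, 17, 29], [1, 7], max_batch_size=4, max_num_pages=128))  # True
print(fits_cuda_graph_buffers([0, 17, 29], [1, 7], max_batch_size=4, max_num_pages=16))   # False
```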
flashinfer.prefill
Attention kernels for prefill & append attention in both single-request and batch serving settings.
Single Request Prefill/Append Attention
Prefill/Append attention with KV cache for a single request; returns the attention output.
Prefill/Append attention with KV cache for a single request; returns the attention output.
Batch Prefill/Append Attention
Performs batched prefill attention with paged KV cache using cuDNN.
- class flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: torch.Tensor | None = None, paged_kv_indptr_buf: torch.Tensor | None = None, paged_kv_indices_buf: torch.Tensor | None = None, paged_kv_last_page_len_buf: torch.Tensor | None = None, custom_mask_buf: torch.Tensor | None = None, mask_indptr_buf: torch.Tensor | None = None, backend: str = 'auto', jit_args: List[Any] | None = None, jit_kwargs: Dict[str, Any] | None = None)
Wrapper class for prefill/append attention with paged kv-cache for a batch of requests.
Check our tutorial for the page table layout.
Example
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 16
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> nnz_qo = 100
>>> qo_indptr = torch.tensor(
...     [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
... )
>>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> paged_kv_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= paged_kv_last_page_len <= page_size
>>> paged_kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
>>> kv_cache_at_layer = torch.randn(
...     num_layers, max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
... )
>>> # create auxiliary data structures for batch prefill attention
>>> prefill_wrapper.plan(
...     qo_indptr,
...     paged_kv_indptr,
...     paged_kv_indices,
...     paged_kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     causal=True,
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     kv_cache = kv_cache_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o = prefill_wrapper.run(q, kv_cache)
...     outputs.append(o)
...
>>> outputs[0].shape
torch.Size([100, 64, 128])
>>>
>>> # below is another example of creating custom mask for batch prefill attention
>>> mask_arr = []
>>> qo_len = (qo_indptr[1:] - qo_indptr[:-1]).cpu().tolist()
>>> kv_len = (page_size * (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) + paged_kv_last_page_len).cpu().tolist()
>>> for i in range(batch_size):
...     mask_i = torch.tril(
...         torch.full((qo_len[i], kv_len[i]), True, device="cuda:0"),
...         diagonal=(kv_len[i] - qo_len[i]),
...     )
...     mask_arr.append(mask_i.flatten())
...
>>> mask = torch.cat(mask_arr, dim=0)
>>> prefill_wrapper.plan(
...     qo_indptr,
...     paged_kv_indptr,
...     paged_kv_indices,
...     paged_kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     custom_mask=mask,
... )
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     kv_cache = kv_cache_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o_custom = prefill_wrapper.run(q, kv_cache)
...     assert torch.allclose(o_custom, outputs[i], rtol=1e-3, atol=1e-3)
...
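The custom-mask example above builds each request's mask with torch.tril and concatenates the flattened rows. The pure-Python sketch below reproduces that layout for two tiny requests and also computes where each request's segment starts in the flattened mask (a cumulative sum of qo_len[i] * kv_len[i], matching the custom_mask shape given under plan()). The helper names are hypothetical, and the offsets refer to the unpacked boolean mask, not to the packed uint8 form.

```python
from typing import List


def causal_mask_flat(qo_lens: List[int], kv_lens: List[int]) -> List[bool]:
    """Flattened per-request masks, concatenated in request order.

    Query row r of a request may see key column c iff c <= r + kv_len - qo_len,
    mirroring torch.tril(..., diagonal=kv_len[i] - qo_len[i]) in the example above.
    """
    flat = []
    for qo_len, kv_len in zip(qo_lens, kv_lens):
        for r in range(qo_len):        # row-major: query rows outer
            for c in range(kv_len):    # key columns inner
                flat.append(c <= r + kv_len - qo_len)
    return flat


def mask_offsets(qo_lens: List[int], kv_lens: List[int]) -> List[int]:
    """Start offset of each request's segment: cumsum of qo_len[i] * kv_len[i]."""
    offsets = [0]
    for qo_len, kv_len in zip(qo_lens, kv_lens):
        offsets.append(offsets[-1] + qo_len * kv_len)
    return offsets


mask = causal_mask_flat([2, 1], [4, 3])
print([int(m) for m in mask])        # [1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1]
print(mask_offsets([2, 1], [4, 3]))  # [0, 8, 11]
```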
Note
To accelerate computation, FlashInfer’s batch prefill/append attention operators create auxiliary data structures that can be reused across multiple prefill/append attention calls (e.g., different Transformer layers). This wrapper class manages the lifecycle of these data structures.
- __init__(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: torch.Tensor | None = None, paged_kv_indptr_buf: torch.Tensor | None = None, paged_kv_indices_buf: torch.Tensor | None = None, paged_kv_last_page_len_buf: torch.Tensor | None = None, custom_mask_buf: torch.Tensor | None = None, mask_indptr_buf: torch.Tensor | None = None, backend: str = 'auto', jit_args: List[Any] | None = None, jit_kwargs: Dict[str, Any] | None = None) -> None
Constructor of BatchPrefillWithPagedKVCacheWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user-reserved workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB; the workspace buffer should be on the same device as the input tensors.
kv_layout (str) – The layout of the input k/v tensors, either NHD or HND.
use_cuda_graph (bool) – Whether to enable CUDA graph capture for the prefill kernels. If enabled, the auxiliary data structures are stored in the provided buffers. The batch_size cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.
qo_indptr_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the qo_indptr array; its size should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.
paged_kv_indptr_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the paged_kv_indptr array; its size should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.
paged_kv_indices_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the paged_kv_indices array. It should be large enough to store the maximum possible size of the paged_kv_indices array during the lifetime of the wrapper. This argument is only effective when use_cuda_graph is True.
paged_kv_last_page_len_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the paged_kv_last_page_len array; its size should be [batch_size]. This argument is only effective when use_cuda_graph is True.
custom_mask_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the custom mask tensor. It should be large enough to store the maximum possible size of the packed custom mask tensor during the lifetime of the wrapper. This argument is only effective when use_cuda_graph is True and a custom mask is used in the attention computation.
mask_indptr_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the mask_indptr array; its size should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True and a custom mask is used in the attention computation.
backend (str) – The implementation backend, either auto, fa2 or fa3. Defaults to auto. If set to auto, the wrapper automatically chooses the backend based on the device architecture and kernel availability.
jit_args (Optional[List[Any]]) – If provided, the wrapper uses these arguments to create the JIT module; otherwise, it uses the default attention implementation.
jit_kwargs (Optional[Dict[str, Any]]) – The keyword arguments used to create the JIT module. Defaults to None.
- plan(qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor, paged_kv_indices: torch.Tensor, paged_kv_last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim_qk: int, page_size: int, head_dim_vo: int | None = None, custom_mask: torch.Tensor | None = None, packed_custom_mask: torch.Tensor | None = None, causal: bool = False, pos_encoding_mode: str = 'NONE', use_fp16_qk_reduction: bool = False, sm_scale: float | None = None, window_left: int = -1, logits_soft_cap: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, q_data_type: str | torch.dtype = 'float16', kv_data_type: str | torch.dtype | None = None, non_blocking: bool = True, prefix_len_ptr: torch.Tensor | None = None, token_pos_in_items_ptr: torch.Tensor | None = None, token_pos_in_items_len: int = 0, max_item_len_ptr: torch.Tensor | None = None, seq_lens: torch.Tensor | None = None, block_tables: torch.Tensor | None = None) -> None
Plan batch prefill/append attention on Paged KV-Cache for given problem specification.
- Parameters:
qo_indptr (torch.Tensor) – The indptr of the query/output tensor, shape: [batch_size + 1].
paged_kv_indptr (torch.Tensor) – The indptr of the paged kv-cache, shape: [batch_size + 1].
paged_kv_indices (torch.Tensor) – The page indices of the paged kv-cache, shape: [paged_kv_indptr[-1]].
paged_kv_last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv-cache, shape: [batch_size].
num_qo_heads (int) – The number of query/output heads.
num_kv_heads (int) – The number of key/value heads.
head_dim_qk (int) – The dimension of the query/key heads.
page_size (int) – The size of each page in the paged kv-cache.
head_dim_vo (Optional[int]) – The dimension of the value/output heads. If not provided, it is set to head_dim_qk.
custom_mask (Optional[torch.Tensor]) –
The flattened boolean mask tensor, shape: [sum(q_len[i] * k_len[i] for i in range(batch_size))]. Each element should be either True or False, where False means the corresponding element in the attention matrix is masked out. Refer to the mask layout for more details about the flattened layout of the mask tensor.
When custom_mask is provided and packed_custom_mask is not, the function packs the custom mask tensor into a 1D packed mask tensor, which introduces additional overhead.
packed_custom_mask (Optional[torch.Tensor]) – The 1D packed uint8 mask tensor. If provided, custom_mask is ignored. The packed mask tensor is generated by flashinfer.quantization.packbits().
causal (bool) – Whether to apply a causal mask to the attention matrix. This is only effective when custom_mask is not provided in plan().
pos_encoding_mode (str) – The position encoding applied inside attention kernels, either NONE, ROPE_LLAMA (LLAMA-style rotary embedding) or ALIBI. Defaults to NONE.
use_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster, at the cost of a slight precision loss).
window_left (int) – The left (inclusive) window size for the attention window. When set to -1, the window covers the full length of the sequence. Defaults to -1.
logits_soft_cap (Optional[float]) – The attention logits soft-capping value (used in Gemini, Grok, Gemma-2, etc.). If not provided, it is set to 0. If greater than 0, the logits are capped according to the formula \(\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})\), where \(x\) is the input logits.
sm_scale (Optional[float]) – The scale used in softmax. If not provided, it is set to 1.0 / sqrt(head_dim).
rope_scale (Optional[float]) – The scale used in RoPE interpolation. If not provided, it is set to 1.0.
rope_theta (Optional[float]) – The theta used in RoPE. If not provided, it is set to 1e4.
q_data_type (Union[str, torch.dtype]) – The data type of the query tensor. Defaults to torch.float16.
kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, it is set to q_data_type.
non_blocking (bool) – Whether to copy the input tensors to the device asynchronously. Defaults to True.
prefix_len_ptr (Optional[torch.Tensor]) – A uint32 1D tensor indicating the prefix length of each prompt. Its size equals the batch size.
token_pos_in_items_ptr (Optional[torch.Tensor]) – A uint16 1D tensor (it is converted to uint16 in FlashInfer) indicating the token position within each item, starting from 0 (the delimiter) for each item. For example, if a prompt has 3 items of lengths 3, 2 and 4, the vector looks like [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0], with the 4 delimiters encoded as 0. For batch size > 1, the per-prompt vectors are concatenated into one 1D tensor, zero-padded so that each has the same length; the padding length for each prompt is token_pos_in_items_len minus the length of its raw token_pos_in_items_ptr.
token_pos_in_items_len (int) – The zero-padded length of token_pos_in_items_ptr for each prompt, used to handle the batch size > 1 case. Continuing the 3, 2, 4 example, if token_pos_in_items_len is 20, the vector becomes [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0] with 7 padded zeros (there are 8 zeros at the end, the first of which is the closing delimiter of the prompt).
max_item_len_ptr (Optional[torch.Tensor]) – A uint16 vector containing the maximum token length of all items for each prompt.
seq_lens (Optional[torch.Tensor]) – A uint32 1D tensor indicating the kv sequence length of each prompt, shape: [batch_size].
block_tables (Optional[torch.Tensor]) – A uint32 2D tensor indicating the block table of each prompt, shape: [batch_size, max_num_blocks_per_seq].
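The following sketch builds the token_pos_in_items vector for one prompt from its item lengths, reproducing the 3, 2, 4 example above, and optionally pads it to token_pos_in_items_len. It returns a Python list rather than a uint16 tensor, and the function name is hypothetical; for a batch, one would presumably concatenate the padded per-prompt vectors as described above, and max(item_lens) gives the per-prompt value that max_item_len_ptr describes.

```python
from typing import List


def token_pos_in_items(item_lens: List[int], padded_len: int = 0) -> List[int]:
    """Build the per-prompt position vector described for token_pos_in_items_ptr.

    Each item is preceded by a delimiter encoded as 0, its tokens are numbered
    1..len, and a final delimiter 0 closes the prompt. If padded_len > 0, zeros
    are appended up to padded_len (the token_pos_in_items_len padding).
    """
    pos = []
    for n in item_lens:
        pos.append(0)                  # delimiter before the item
        pos.extend(range(1, n + 1))    # positions inside the item
    pos.append(0)                      # closing delimiter
    if padded_len:
        if padded_len < len(pos):
            raise ValueError("padded_len is shorter than the unpadded vector")
        pos.extend([0] * (padded_len - len(pos)))
    return pos


print(token_pos_in_items([3, 2, 4]))
# [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
print(len(token_pos_in_items([3, 2, 4], padded_len=20)))  # 20
```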
Note
The plan() method should be called before any run() or run_return_lse() calls; auxiliary data structures are created during this call and cached for multiple kernel runs.
The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function uses grouped query attention.
The plan() method cannot be used in CUDA Graph or in torch.compile.
- reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor) -> None
Reset the workspace buffer.
- Parameters:
float_workspace_buffer (torch.Tensor) – The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors.
int_workspace_buffer (torch.Tensor) – The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors.
- run(q: torch.Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], *args, k_scale: float | None = None, v_scale: float | None = None, out: torch.Tensor | None = None, lse: torch.Tensor | None = None, return_lse: Literal[False] = False, enable_pdl: bool | None = None, window_left: int | None = None) -> torch.Tensor
- run(q: torch.Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], *args, k_scale: float | None = None, v_scale: float | None = None, out: torch.Tensor | None = None, lse: torch.Tensor | None = None, return_lse: Literal[True] = True, enable_pdl: bool | None = None, window_left: int | None = None) -> Tuple[torch.Tensor, torch.Tensor]
Compute batch prefill/append attention between query and paged kv-cache.
- Parameters:
q (torch.Tensor) – The query tensor, shape: [qo_indptr[-1], num_qo_heads, head_dim]
paged_kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) –
The paged KV-Cache stored as a tuple of tensors or a single tensor:
  - a tuple (k_cache, v_cache) of 4-D tensors, each with shape [max_num_pages, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, num_kv_heads, page_size, head_dim] if kv_layout is HND.
  - a single 5-D tensor with shape [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND, where paged_kv_cache[:, 0] is the key-cache and paged_kv_cache[:, 1] is the value-cache.
*args – Additional arguments for custom kernels.
k_scale (Optional[float]) – The calibration scale of the key for fp8 input. If not provided, it is set to 1.0.
v_scale (Optional[float]) – The calibration scale of the value for fp8 input. If not provided, it is set to 1.0.
out (Optional[torch.Tensor]) – The output tensor. If not provided, it is allocated internally.
lse (Optional[torch.Tensor]) – The log-sum-exp of attention logits. If not provided, it is allocated internally.
return_lse (bool) – Whether to return the log-sum-exp of attention scores. Defaults to False.
enable_pdl (bool) – Whether to enable Programmatic Dependent Launch (PDL); see https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization for details. Only supported on sm90 and newer, and currently only for FA2 and CUDA-core decode.
- Returns:
If return_lse is False, the attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim]. If return_lse is True, a tuple of two tensors:
  - the attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim];
  - the logsumexp of attention scores, shape: [qo_indptr[-1], num_qo_heads].
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
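The log-sum-exp returned with return_lse=True is what allows partial attention results to be combined, for example when the same queries attend to two disjoint sets of KV entries in separate calls. The sketch below merges two (output, lse) pairs for one query head and checks the result against attention over the full set. It is a plain-Python illustration, not a FlashInfer API; since the base of the returned lse is not specified here, the helper takes the base as a parameter (the merge is exact as long as both inputs use the same base).

```python
import math
from typing import List, Tuple


def merge_attention_states(
    o_a: List[float], lse_a: float, o_b: List[float], lse_b: float, base: float = math.e
) -> Tuple[List[float], float]:
    """Merge two partial attention results over disjoint KV subsets.

    lse_a / lse_b are log-sum-exps of the logits expressed in `base`
    (log_base of sum(exp(logit))); both inputs must use the same base.
    """
    m = max(lse_a, lse_b)  # subtract the max for numerical stability
    lse = m + math.log(base ** (lse_a - m) + base ** (lse_b - m), base)
    w_a, w_b = base ** (lse_a - lse), base ** (lse_b - lse)  # softmax mass of each part
    return [w_a * a + w_b * b for a, b in zip(o_a, o_b)], lse


def attend(logits: List[float], values: List[List[float]]) -> Tuple[List[float], float]:
    """Reference attention for one query: returns (output, natural-log LSE)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    probs = [math.exp(x - lse) for x in logits]
    return [sum(p * v[d] for p, v in zip(probs, values)) for d in range(len(values[0]))], lse


logits, values = [1.0, 2.0, 3.0], [[10.0], [20.0], [30.0]]
full_o, full_lse = attend(logits, values)
o1, lse1 = attend(logits[:2], values[:2])  # first KV chunk
o2, lse2 = attend(logits[2:], values[2:])  # second KV chunk
merged_o, merged_lse = merge_attention_states(o1, lse1, o2, lse2)
print(math.isclose(merged_o[0], full_o[0]), math.isclose(merged_lse, full_lse))  # True True
```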
- class flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: torch.Tensor | None = None, kv_indptr_buf: torch.Tensor | None = None, custom_mask_buf: torch.Tensor | None = None, mask_indptr_buf: torch.Tensor | None = None, backend: str = 'auto', jit_args: List[Any] | None = None, jit_kwargs: Dict[str, Any] | None = None)
Wrapper class for prefill/append attention with ragged (tensor) kv-cache for a batch of requests.
Check our tutorial for the ragged kv-cache layout.
Example
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 16
>>> head_dim = 128
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> nnz_kv = 100
>>> nnz_qo = 100
>>> qo_indptr = torch.tensor(
...     [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
... )
>>> kv_indptr = qo_indptr.clone()
>>> q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
>>> k_at_layer = torch.randn(num_layers, nnz_kv, num_kv_heads, head_dim).half().to("cuda:0")
>>> v_at_layer = torch.randn(num_layers, nnz_kv, num_kv_heads, head_dim).half().to("cuda:0")
>>> # create auxiliary data structures for batch prefill attention
>>> prefill_wrapper.plan(
...     qo_indptr,
...     kv_indptr,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     causal=True,
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     k = k_at_layer[i]
...     v = v_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o = prefill_wrapper.run(q, k, v)
...     outputs.append(o)
...
>>> outputs[0].shape
torch.Size([100, 64, 128])
>>>
>>> # below is another example of creating custom mask for batch prefill attention
>>> mask_arr = []
>>> qo_len = (qo_indptr[1:] - qo_indptr[:-1]).cpu().tolist()
>>> kv_len = (kv_indptr[1:] - kv_indptr[:-1]).cpu().tolist()
>>> for i in range(batch_size):
...     mask_i = torch.tril(
...         torch.full((qo_len[i], kv_len[i]), True, device="cuda:0"),
...         diagonal=(kv_len[i] - qo_len[i]),
...     )
...     mask_arr.append(mask_i.flatten())
...
>>> mask = torch.cat(mask_arr, dim=0)
>>> prefill_wrapper.plan(
...     qo_indptr,
...     kv_indptr,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     custom_mask=mask
... )
>>> outputs_custom_mask = []
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     k = k_at_layer[i]
...     v = v_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o_custom = prefill_wrapper.run(q, k, v)
...     assert torch.allclose(o_custom, outputs[i], rtol=1e-3, atol=1e-3)
...     outputs_custom_mask.append(o_custom)
...
>>> outputs_custom_mask[0].shape
torch.Size([100, 64, 128])
Note
To accelerate computation, FlashInfer’s batch prefill/append attention operators create some auxiliary data structures. These data structures can be reused across multiple prefill/append attention calls (e.g., different Transformer layers), and this wrapper class manages their lifecycle.
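The qo_indptr and kv_indptr arrays passed to plan() are CSR-style prefix sums. Request i owns rows qo_indptr[i] up to (but not including) qo_indptr[i + 1] of the ragged q tensor, and kv_indptr plays the same role for k/v. The plain-Python sketch below makes no FlashInfer calls, and its helper names are hypothetical. It rebuilds the qo_indptr from the example above out of per-request lengths. It also computes where each request's slice of the flattened custom_mask begins. That second step assumes each request's row-major qo_len × kv_len block is concatenated in request order, which is how the mask_arr construction in the example builds it.

```python
from itertools import accumulate


def lengths_to_indptr(lengths):
    """Hypothetical helper: prefix-sum per-request lengths into an indptr of size len(lengths) + 1."""
    return [0] + list(accumulate(lengths))


def request_rows(indptr, i):
    """Rows [indptr[i], indptr[i + 1]) of a ragged tensor belong to request i."""
    return slice(indptr[i], indptr[i + 1])


# Per-request query lengths that reproduce qo_indptr from the example above.
qo_lens = [33, 11, 11, 11, 11, 11, 12]
kv_lens = list(qo_lens)  # the example uses kv_indptr = qo_indptr.clone()

qo_indptr = lengths_to_indptr(qo_lens)
kv_indptr = lengths_to_indptr(kv_lens)
print(qo_indptr)  # [0, 33, 44, 55, 66, 77, 88, 100]

# Offsets into the flattened custom_mask: request i contributes qo_len[i] * kv_len[i] booleans.
mask_offsets = lengths_to_indptr([q * k for q, k in zip(qo_lens, kv_lens)])
print(mask_offsets[-1])  # total mask length: 1838
```

In the real API these lists would be int32 CUDA tensors, as in the example above.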
- __init__(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: torch.Tensor | None = None, kv_indptr_buf: torch.Tensor | None = None, custom_mask_buf: torch.Tensor | None = None, mask_indptr_buf: torch.Tensor | None = None, backend: str = 'auto', jit_args: List[Any] | None = None, jit_kwargs: Dict[str, Any] | None = None) -> None
Constructor of BatchPrefillWithRaggedKVCacheWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user-reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB; the workspace buffer should be on the same device as the input tensors.
kv_layout (str) – The layout of the input k/v tensors, either NHD or HND.
use_cuda_graph (bool) – Whether to enable CUDA graph capture for the prefill kernels. If enabled, the auxiliary data structures will be stored in the provided buffers.
qo_indptr_buf (Optional[torch.Tensor]) – The user-reserved GPU buffer to store the qo_indptr array; the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.
kv_indptr_buf (Optional[torch.Tensor]) – The user-reserved GPU buffer to store the kv_indptr array; the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.
custom_mask_buf (Optional[torch.Tensor]) – The user-reserved GPU buffer to store the custom mask tensor. It should be large enough to hold the largest packed custom mask tensor used during the lifetime of the wrapper. This argument is only effective when use_cuda_graph is True and a custom mask is used in the attention computation.
mask_indptr_buf (Optional[torch.Tensor]) – The user-reserved GPU buffer to store the mask_indptr array; the size of the buffer should be [batch_size]. This argument is only effective when use_cuda_graph is True and a custom mask is used in the attention computation.
backend (str) – The implementation backend, one of auto / fa2 / fa3 / trtllm-gen. Defaults to auto. If set to auto, the wrapper automatically chooses the backend based on the device architecture and kernel availability.
jit_args (Optional[List[Any]]) – If provided, the wrapper uses these arguments to create the JIT module; otherwise, it uses the default attention implementation.
jit_kwargs (Optional[Dict[str, Any]]) – The keyword arguments used to create the JIT module. Defaults to None.
- plan(qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim_qk: int, head_dim_vo: int | None = None, custom_mask: torch.Tensor | None = None, packed_custom_mask: torch.Tensor | None = None, causal: bool = False, pos_encoding_mode: str = 'NONE', use_fp16_qk_reduction: bool = False, window_left: int = -1, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, q_data_type: str | torch.dtype = 'float16', kv_data_type: str | torch.dtype | None = None, non_blocking: bool = True, prefix_len_ptr: torch.Tensor | None = None, token_pos_in_items_ptr: torch.Tensor | None = None, token_pos_in_items_len: int = 0, max_item_len_ptr: torch.Tensor | None = None) -> None
Plan batch prefill/append attention on Ragged KV-Cache for given problem specification.
- Parameters:
qo_indptr (torch.Tensor) – The indptr of the query/output tensor, shape: [batch_size + 1].
kv_indptr (torch.Tensor) – The indptr of the key/value tensor, shape: [batch_size + 1].
num_qo_heads (int) – The number of query/output heads.
num_kv_heads (int) – The number of key/value heads.
head_dim_qk (int) – The dimension of the heads on query/key tensor.
head_dim_vo (Optional[int]) – The dimension of the heads on value/output tensor. If not provided, it is set to head_dim_qk.
custom_mask (Optional[torch.Tensor]) – The flattened boolean mask tensor, shape: [sum(q_len[i] * k_len[i] for i in range(batch_size))]. The elements in the mask tensor should be either True or False, where False means the corresponding element in the attention matrix will be masked out. Please refer to the mask layout for more details about the flattened layout of the mask tensor. When custom_mask is provided and packed_custom_mask is not, the function packs the custom mask tensor into a 1D packed mask tensor, which introduces additional overhead.
packed_custom_mask (Optional[torch.Tensor]) – The 1D packed uint8 mask tensor. If provided, custom_mask is ignored. The packed mask tensor is generated by flashinfer.quantization.packbits(). When a mask is provided, it is applied to the attention matrix after scaling and before softmax. The mask tensor should be on the same device as the input tensors.
causal (bool) – Whether to apply a causal mask to the attention matrix. This argument is ignored if a mask is provided in plan().
pos_encoding_mode (str) – The position encoding applied inside attention kernels, one of NONE / ROPE_LLAMA (LLAMA-style rotary embedding) / ALIBI. Default is NONE.
use_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).
window_left (int) – The left (inclusive) window size for the attention window. When set to -1, the window size is the full length of the sequence. Defaults to -1.
logits_soft_cap (Optional[float]) – The attention logits soft-capping value (used in Gemini, Grok, Gemma-2, etc.). If not provided, it is set to 0. If greater than 0, the logits are capped according to the formula \(\texttt{logits\_soft\_cap} \times \tanh(x / \texttt{logits\_soft\_cap})\), where \(x\) is the input logits.
sm_scale (Optional[float]) – The scale used in softmax. If not provided, it is set to 1.0 / sqrt(head_dim_qk).
rope_scale (Optional[float]) – The scale used in RoPE interpolation. If not provided, it is set to 1.0.
rope_theta (Optional[float]) – The theta used in RoPE. If not provided, it is set to 1e4.
q_data_type (Union[str, torch.dtype]) – The data type of the query tensor. Defaults to torch.float16.
kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, it is set to q_data_type.
non_blocking (bool) – Whether to copy the input tensors to the device asynchronously. Defaults to True.
prefix_len_ptr (Optional[torch.Tensor]) – The prefix length of each prompt, given as a uint32 1D tensor whose size equals the batch size.
token_pos_in_items_ptr (Optional[torch.Tensor]) – A uint16 1D tensor (FlashInfer converts it to uint16) giving the position of each token within its item. Each item starts with a delimiter at position 0. For example, if a prompt has 3 items of lengths 3, 2 and 4, this vector looks like [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0], with 4 delimiters encoded as 0. For batch size > 1, the per-prompt vectors are concatenated into one 1D tensor, and each is zero-padded to the same length. The padding length for each prompt is token_pos_in_items_len minus the length of its raw token_pos_in_items_ptr.
token_pos_in_items_len (int) – The padded length of each prompt's token_pos_in_items_ptr, used to handle the batch size > 1 case. In the 3, 2, 4 example above, setting token_pos_in_items_len to 20 turns the vector into [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0], with 7 padded zeros. (Note that there are 8 zeros at the end: the first one is the delimiter 0 that closes the prompt.)
max_item_len_ptr (Optional[torch.Tensor]) – A uint16 vector containing the maximum token length over all items in each prompt.
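The token_pos_in_items_ptr encoding is easiest to understand by building it. The sketch below is plain Python, and the helper name build_token_pos_in_items is hypothetical. It reproduces the 3, 2, 4 example from the parameter descriptions and pads each prompt to token_pos_in_items_len. It then concatenates a small batch and derives one maximum item length per prompt, in the spirit of max_item_len_ptr. Converting the lists into uint16 CUDA tensors is left out.

```python
def build_token_pos_in_items(item_lens, padded_len=None):
    """Hypothetical helper mirroring the documented token_pos_in_items layout.

    Every item is preceded by a delimiter 0 and numbered 1..len(item); one more
    delimiter 0 closes the prompt. If padded_len is given, zeros are appended
    until the list has exactly padded_len entries.
    """
    pos = []
    for n in item_lens:
        pos.append(0)                # delimiter before the item
        pos.extend(range(1, n + 1))  # positions inside the item
    pos.append(0)                    # delimiter at the end of the prompt
    if padded_len is not None:
        if padded_len < len(pos):
            raise ValueError("padded_len is shorter than the encoded prompt")
        pos.extend([0] * (padded_len - len(pos)))
    return pos


raw = build_token_pos_in_items([3, 2, 4])
print(raw)  # [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]

# Batch of two prompts: pad each to the same length, then concatenate into one 1D list.
token_pos_in_items_len = 20
batch_items = [[3, 2, 4], [5, 1]]
flat = []
for items in batch_items:
    flat.extend(build_token_pos_in_items(items, token_pos_in_items_len))
max_item_len = [max(items) for items in batch_items]  # one value per prompt
```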
Note
The plan() method should be called before any run() or run_return_lse() calls. Auxiliary data structures are created during this plan call and cached for multiple kernel runs.
The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function uses grouped query attention.
The plan() method cannot be used in CUDA Graph or in torch.compile.
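Here is a naive NumPy reference for a single request of the ragged batch. It makes the semantics of causal, window_left, logits_soft_cap, sm_scale and grouped query attention concrete. It is a sketch rather than FlashInfer's kernel, and several details are assumptions:

- The causal diagonal is aligned to the end of the KV sequence, matching the tril(..., diagonal=kv_len - qo_len) mask in the example above.
- window_left is measured backwards from each query's aligned position.
- The soft cap is applied to the already-scaled logits.
- Query head h reads KV head h // (num_qo_heads // num_kv_heads).
- pos_encoding_mode is NONE.

```python
import numpy as np


def reference_ragged_attention(q, k, v, causal=False, window_left=-1,
                               logits_soft_cap=None, sm_scale=None):
    """Naive attention for ONE request of a ragged batch (a sketch, not FlashInfer's kernel).

    q: [qo_len, num_qo_heads, head_dim_qk]
    k: [kv_len, num_kv_heads, head_dim_qk]
    v: [kv_len, num_kv_heads, head_dim_vo]
    Returns: [qo_len, num_qo_heads, head_dim_vo]
    """
    qo_len, num_qo_heads, head_dim_qk = q.shape
    kv_len, num_kv_heads, _ = k.shape
    assert num_qo_heads % num_kv_heads == 0, "num_qo_heads must be a multiple of num_kv_heads"
    group_size = num_qo_heads // num_kv_heads
    if sm_scale is None:
        sm_scale = 1.0 / np.sqrt(head_dim_qk)

    # Assumption: query row i sits at position i + (kv_len - qo_len) of the KV sequence.
    q_pos = np.arange(qo_len)[:, None] + (kv_len - qo_len)
    kv_pos = np.arange(kv_len)[None, :]
    allowed = np.ones((qo_len, kv_len), dtype=bool)
    if causal:
        assert kv_len >= qo_len, "otherwise some query rows would see no keys"
        allowed &= kv_pos <= q_pos
    if window_left >= 0:
        allowed &= kv_pos >= q_pos - window_left  # left window, inclusive

    out = np.empty((qo_len, num_qo_heads, v.shape[-1]))
    for h in range(num_qo_heads):
        kv_h = h // group_size  # assumed GQA mapping: consecutive query heads share a KV head
        logits = (q[:, h, :] @ k[:, kv_h, :].T) * sm_scale
        if logits_soft_cap:  # None or 0 disables soft capping
            logits = logits_soft_cap * np.tanh(logits / logits_soft_cap)
        logits = np.where(allowed, logits, -np.inf)
        logits = logits - logits.max(axis=-1, keepdims=True)
        p = np.exp(logits)
        p /= p.sum(axis=-1, keepdims=True)
        out[:, h, :] = p @ v[:, kv_h, :]
    return out


rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8, 16))   # qo_len=4, num_qo_heads=8
k = rng.standard_normal((6, 2, 16))   # kv_len=6, num_kv_heads=2 -> group size 4
v = rng.standard_normal((6, 2, 16))
o = reference_ragged_attention(q, k, v, causal=True)
```

For a whole batch, you would loop over requests and slice q, k and v with qo_indptr and kv_indptr. Comparing such a loop against the kernel on small inputs, up to fp16 tolerances, is a common sanity check.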
- reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer) -> None
Reset the workspace buffer.
- Parameters:
float_workspace_buffer (torch.Tensor) – The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors.
int_workspace_buffer (torch.Tensor) – The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors.
- run(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, out: torch.Tensor | None = None, lse: torch.Tensor | None = None, return_lse: Literal[False] = False, enable_pdl: bool | None = None) -> torch.Tensor
- run(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, out: torch.Tensor | None = None, lse: torch.Tensor | None = None, return_lse: Literal[True] = True, enable_pdl: bool | None = None) -> Tuple[torch.Tensor, torch.Tensor]
Compute batch prefill/append attention between query and kv-cache stored as ragged tensor.
- Parameters:
q (torch.Tensor) – The query tensor, shape: [qo_indptr[-1], num_qo_heads, head_dim_qk].
k (torch.Tensor) – The key tensor, shape: [kv_indptr[-1], num_kv_heads, head_dim_qk].
v (torch.Tensor) – The value tensor, shape: [kv_indptr[-1], num_kv_heads, head_dim_vo].
*args – Additional arguments for the custom kernel.
out (Optional[torch.Tensor]) – The output tensor, if not provided, will be allocated internally.
lse (Optional[torch.Tensor]) – The log-sum-exp of attention logits, if not provided, will be allocated internally.
return_lse (bool) – Whether to return the logsumexp of attention output
enable_pdl (bool) – Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Only supported for >= sm90, and currently only for FA2 and CUDA core decode.
- Returns:
If return_lse is False, the attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim_vo]. If return_lse is True, a tuple of two tensors:
The attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim_vo].
The logsumexp of attention output, shape: [qo_indptr[-1], num_qo_heads].
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
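With return_lse=True, run() also returns the log-sum-exp of each row's attention logits. This is useful because attention over a KV sequence split into disjoint chunks can be recombined exactly from the per-chunk outputs and LSEs. A typical split is a shared prefix plus a per-request suffix. The NumPy sketch below demonstrates the identity for a single head, and its helper names are hypothetical. It uses natural-log LSE, which is an assumption here: this section does not state the log base of the lse tensor FlashInfer returns, so check that before mixing it with this code.

```python
import numpy as np


def attend_with_lse(q, k, v):
    """Single-head attention over one KV chunk; returns (output [n_q, d], natural-log LSE [n_q])."""
    logits = q @ k.T
    m = logits.max(axis=-1, keepdims=True)
    p = np.exp(logits - m)
    s = p.sum(axis=-1, keepdims=True)
    return (p / s) @ v, (m + np.log(s))[:, 0]


def merge_states(o_a, lse_a, o_b, lse_b):
    """Combine results from two disjoint KV chunks, weighting each by its share of the softmax mass."""
    lse = np.logaddexp(lse_a, lse_b)
    w_a = np.exp(lse_a - lse)[:, None]
    w_b = np.exp(lse_b - lse)[:, None]
    return w_a * o_a + w_b * o_b, lse


rng = np.random.default_rng(0)
q = rng.standard_normal((3, 8))
k = rng.standard_normal((10, 8))
v = rng.standard_normal((10, 8))

o_full, lse_full = attend_with_lse(q, k, v)
o_a, lse_a = attend_with_lse(q, k[:4], v[:4])   # e.g. shared-prefix chunk
o_b, lse_b = attend_with_lse(q, k[4:], v[4:])   # e.g. per-request suffix chunk
o_merged, lse_merged = merge_states(o_a, lse_a, o_b, lse_b)
```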
flashinfer.mla
MLA (Multi-head Latent Attention) is an attention mechanism proposed in the DeepSeek series of models (DeepSeek-V2, DeepSeek-V3, and DeepSeek-R1).
PageAttention for MLA
- class flashinfer.mla.BatchMLAPagedAttentionWrapper(float_workspace_buffer: torch.Tensor, use_cuda_graph: bool = False, qo_indptr: torch.Tensor | None = None, kv_indptr: torch.Tensor | None = None, kv_indices: torch.Tensor | None = None, kv_len_arr: torch.Tensor | None = None, backend: str = 'auto')
Wrapper class for MLA (Multi-head Latent Attention) PagedAttention on DeepSeek models. This kernel can be used in decode and incremental prefill. It should be used together with the Matrix Absorption trick, where \(W_{UQ}\) is absorbed with \(W_{UK}\), and \(W_{UV}\) is absorbed with \(W_{O}\). For MLA attention without Matrix Absorption (head_dim_qk=192 and head_dim_vo=128, which is used in the prefill self-attention stage), please use flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper.
More information about the Paged KV-Cache layout in MLA is explained in our tutorial MLA Page Layout.
For more details about the MLA computation, Matrix Absorption and FlashInfer’s MLA implementation, please refer to our blog post.
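The Matrix Absorption trick rests on the associativity of matrix products. The NumPy sketch below checks it for one head, using toy sizes and random weights; the variable names are illustrative, not FlashInfer's. It omits the RoPE part (q_pe/kpe) and the softmax scale, because neither affects the argument. First, it checks that scoring the decompressed keys gives the same result as scoring the compressed cache with a query that has W_UK folded in. Second, it checks that attending over the compressed cache and then applying W_UV reproduces the naive output. This is consistent with the example below, where q_nope has head_dim_ckv columns and the output width is head_dim_ckv.

```python
import numpy as np

rng = np.random.default_rng(0)
d_latent, d_head, kv_len = 16, 8, 5   # toy sizes; DeepSeek uses 512 and 128

c_kv = rng.standard_normal((kv_len, d_latent))   # compressed KV cache (ckv)
W_UK = rng.standard_normal((d_head, d_latent))   # up-projection latent -> key
W_UV = rng.standard_normal((d_head, d_latent))   # up-projection latent -> value
q_nope = rng.standard_normal(d_head)             # one head's non-RoPE query


def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


# Naive: decompress keys/values, then attend in the per-head space.
k = c_kv @ W_UK.T
v = c_kv @ W_UV.T
p_naive = softmax(k @ q_nope)
o_naive = p_naive @ v

# Absorbed: fold W_UK into the query, attend directly over the latent cache,
# then apply W_UV (in a real model W_UV would be folded into W_O instead).
q_absorbed = W_UK.T @ q_nope        # shape [d_latent], like q_nope with head_dim_ckv
p_absorbed = softmax(c_kv @ q_absorbed)
o_latent = p_absorbed @ c_kv        # shape [d_latent], like the last dim of run()'s output
o_absorbed = W_UV @ o_latent
```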
Example
>>> import torch
>>> import flashinfer
>>> num_local_heads = 128
>>> batch_size = 114
>>> head_dim_ckv = 512
>>> head_dim_kpe = 64
>>> page_size = 1
>>> mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
...     torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0),
...     backend="fa2"
... )
>>> q_indptr = torch.arange(0, batch_size + 1).to(0).int() # for decode, each query length is 1
>>> kv_lens = torch.full((batch_size,), 999, dtype=torch.int32).to(0)
>>> kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * 999
>>> kv_indices = torch.arange(0, batch_size * 999).to(0).int()
>>> q_nope = torch.randn(
...     batch_size * 1, num_local_heads, head_dim_ckv, dtype=torch.bfloat16, device="cuda"
... )
>>> q_pe = torch.zeros(
...     batch_size * 1, num_local_heads, head_dim_kpe, dtype=torch.bfloat16, device="cuda"
... )
>>> ckv = torch.randn(
...     batch_size * 999, 1, head_dim_ckv, dtype=torch.bfloat16, device="cuda"
... )
>>> kpe = torch.zeros(
...     batch_size * 999, 1, head_dim_kpe, dtype=torch.bfloat16, device="cuda"
... )
>>> sm_scale = 1.0 / ((128 + 64) ** 0.5) # use head dimension before matrix absorption
>>> mla_wrapper.plan(
...     q_indptr,
...     kv_indptr,
...     kv_indices,
...     kv_lens,
...     num_local_heads,
...     head_dim_ckv,
...     head_dim_kpe,
...     page_size,
...     False, # causal
...     sm_scale,
...     q_nope.dtype,
...     ckv.dtype,
... )
>>> o = mla_wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False)
>>> o.shape
torch.Size([114, 128, 512])
- __init__(float_workspace_buffer: torch.Tensor, use_cuda_graph: bool = False, qo_indptr: torch.Tensor | None = None, kv_indptr: torch.Tensor | None = None, kv_indices: torch.Tensor | None = None, kv_len_arr: torch.Tensor | None = None, backend: str = 'auto') -> None
Constructor for BatchMLAPagedAttentionWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user-reserved workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB; the workspace buffer should be on the same device as the input tensors.
use_cuda_graph (bool, optional) – Whether to enable CUDA graph capture for the prefill kernels. If enabled, the auxiliary data structures will be stored in the provided buffers. The batch_size cannot change during the lifecycle of this wrapper when CUDA graph is enabled.
qo_indptr (Optional[torch.Tensor]) – The user-reserved buffer to store the qo_indptr array; the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.
kv_indptr (Optional[torch.Tensor]) – The user-reserved buffer to store the kv_indptr array; the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.
kv_indices (Optional[torch.Tensor]) – The user-reserved buffer to store the kv_indices array. This argument is only effective when use_cuda_graph is True.
kv_len_arr (Optional[torch.Tensor]) – The user-reserved buffer to store the kv_len_arr array; the size of the buffer should be [batch_size]. This argument is only effective when use_cuda_graph is True.
backend (str) – The implementation backend, one of auto / fa2 / fa3 / cutlass. Defaults to auto. If set to auto, the function automatically chooses the backend based on the device architecture and kernel availability. If cutlass is provided, the MLA kernels are generated by CUTLASS, only float_workspace_buffer is required, and the other arguments are ignored.
- plan(qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, kv_len_arr: torch.Tensor, num_heads: int, head_dim_ckv: int, head_dim_kpe: int, page_size: int, causal: bool, sm_scale: float, q_data_type: torch.dtype, kv_data_type: torch.dtype, use_profiler: bool = False) -> None
Plan the MLA attention computation.
- Parameters:
qo_indptr (torch.Tensor) – The indptr of the query/output tensor, shape: [batch_size + 1]. For decoding attention, the length of each query is 1, and the content of the tensor should be [0, 1, 2, ..., batch_size].
kv_indptr (torch.Tensor) – The indptr of the paged kv-cache, shape: [batch_size + 1].
kv_indices (torch.Tensor) – The page indices of the paged kv-cache, shape: [kv_indptr[-1]] or larger.
kv_len_arr (torch.Tensor) – The KV length of each request, shape: [batch_size].
.num_heads (int) – The number of heads in query/output tensor.
head_dim_ckv (int) – The head dimension of compressed-kv.
head_dim_kpe (int) – The head dimension for rope k-cache.
page_size (int) – The page size of the paged kv-cache.
causal (bool) – Whether to use causal attention.
sm_scale (float) – The scale factor for softmax operation.
q_data_type (torch.dtype) – The data type of the query tensor.
kv_data_type (torch.dtype) – The data type of the kv-cache tensor.
use_profiler (bool, optional) – Whether to enable intra-kernel profiler, default is False.
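This sketch assumes the following reading of the plan() inputs, which follows the example above, where page_size=1 makes kv_indptr equal to cumulative token counts:

- kv_indptr counts pages.
- kv_indices lists page ids.
- kv_len_arr carries the token count of each request, so a partially filled last page is described by kv_len_arr rather than by the indptr.

The plain-Python sketch builds the four plan() inputs for a decode batch. Its helper name is hypothetical, and the contiguous page allocation is chosen purely for illustration.

```python
def build_mla_decode_metadata(kv_lens, page_size, page_ids=None):
    """Hypothetical helper: returns (qo_indptr, kv_indptr, kv_indices, kv_len_arr) as Python lists."""
    batch_size = len(kv_lens)
    qo_indptr = list(range(batch_size + 1))  # decode: exactly one query token per request
    kv_indptr = [0]
    for n in kv_lens:
        kv_indptr.append(kv_indptr[-1] + (n + page_size - 1) // page_size)  # ceil(n / page_size) pages
    if page_ids is None:
        page_ids = list(range(kv_indptr[-1]))  # illustration only: pages allocated contiguously
    if len(page_ids) < kv_indptr[-1]:
        raise ValueError("not enough page ids for the requested KV lengths")
    kv_indices = list(page_ids[: kv_indptr[-1]])
    return qo_indptr, kv_indptr, kv_indices, list(kv_lens)


# Three requests with 1, 64 and 65 cached tokens in 64-token pages.
qo_indptr, kv_indptr, kv_indices, kv_len_arr = build_mla_decode_metadata([1, 64, 65], page_size=64)
print(kv_indptr)  # [0, 1, 2, 4]
```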
- run(q_nope: torch.Tensor, q_pe: torch.Tensor, ckv_cache: torch.Tensor, kpe_cache: torch.Tensor, out: torch.Tensor | None = None, lse: torch.Tensor | None = None, return_lse: Literal[False] = False, profiler_buffer: torch.Tensor | None = None, kv_len: torch.Tensor | None = None, page_table: torch.Tensor | None = None) -> torch.Tensor
- run(q_nope: torch.Tensor, q_pe: torch.Tensor, ckv_cache: torch.Tensor, kpe_cache: torch.Tensor, out: torch.Tensor | None = None, lse: torch.Tensor | None = None, return_lse: Literal[True] = True, profiler_buffer: torch.Tensor | None = None, kv_len: torch.Tensor | None = None, page_table: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]
Run the MLA attention computation.
- Parameters:
q_nope (torch.Tensor) – The query tensor without RoPE, shape: [batch_size, num_heads, head_dim_ckv].
q_pe (torch.Tensor) – The RoPE part of the query tensor, shape: [batch_size, num_heads, head_dim_kpe].
ckv_cache (torch.Tensor) – The compressed kv-cache tensor (without RoPE), shape: [num_pages, page_size, head_dim_ckv]. head_dim_ckv is 512 in DeepSeek v2/v3 models.
kpe_cache (torch.Tensor) – The RoPE part of the kv-cache tensor, shape: [num_pages, page_size, head_dim_kpe]. head_dim_kpe is 64 in DeepSeek v2/v3 models.
out (Optional[torch.Tensor]) – The output tensor. If not provided, it will be allocated internally.
lse (Optional[torch.Tensor]) – The log-sum-exp of attention logits, if not provided, will be allocated internally.
return_lse (bool, optional) – Whether to return the log-sum-exp value, default is False.
profiler_buffer (Optional[torch.Tensor]) – The buffer to store the profiler data.
kv_len (Optional[torch.Tensor]) – The KV length of each request, shape: [batch_size]. Required when backend is cutlass.
page_table (Optional[torch.Tensor]) – The page table of the paged kv-cache, shape: [batch_size, num_pages]. Required when backend is cutlass.
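A naive NumPy reference of what run() computes for one decode request can help when checking outputs on small inputs. It assumes the absorbed formulation from the class description:

- The logits are (q_nope · ckv + q_pe · kpe) × sm_scale.
- The compressed cache rows also serve as the values, which is consistent with the output width being head_dim_ckv in the example.

The gather step reads a request's pages through kv_indices/kv_indptr and truncates them to the request's KV length. The function names are hypothetical.

```python
import numpy as np


def gather_request(cache, kv_indices, kv_indptr, kv_len, i):
    """Collect request i's rows from a paged cache of shape [num_pages, page_size, dim]."""
    pages = kv_indices[kv_indptr[i]:kv_indptr[i + 1]]
    rows = cache[pages].reshape(-1, cache.shape[-1])  # pages in kv_indices order
    return rows[: kv_len[i]]                          # drop unused slots of the last page


def reference_mla_decode(q_nope, q_pe, ckv, kpe, sm_scale):
    """q_nope: [num_heads, head_dim_ckv], q_pe: [num_heads, head_dim_kpe],
    ckv: [kv_len, head_dim_ckv], kpe: [kv_len, head_dim_kpe] -> [num_heads, head_dim_ckv]."""
    logits = (q_nope @ ckv.T + q_pe @ kpe.T) * sm_scale
    logits = logits - logits.max(axis=-1, keepdims=True)
    p = np.exp(logits)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ ckv  # assumption: the compressed KV rows double as values


rng = np.random.default_rng(0)
num_heads, head_dim_ckv, head_dim_kpe, page_size = 2, 8, 4, 2
ckv_cache = rng.standard_normal((4, page_size, head_dim_ckv))  # 4 pages
kpe_cache = rng.standard_normal((4, page_size, head_dim_kpe))
kv_indices, kv_indptr, kv_len = [3, 1], [0, 2], [3]             # one request, 3 tokens on pages 3 and 1

ckv = gather_request(ckv_cache, kv_indices, kv_indptr, kv_len, 0)
kpe = gather_request(kpe_cache, kv_indices, kv_indptr, kv_len, 0)
q_nope = rng.standard_normal((num_heads, head_dim_ckv))
q_pe = rng.standard_normal((num_heads, head_dim_kpe))
o = reference_mla_decode(q_nope, q_pe, ckv, kpe, sm_scale=1.0 / np.sqrt(128 + 64))
```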