flashinfer.decode
Single Request Decoding
Decode attention with a KV cache for a single request; returns the attention output.
Batch Decoding
- class flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, use_tensor_cores: bool = False, paged_kv_indptr_buffer: torch.Tensor | None = None, paged_kv_indices_buffer: torch.Tensor | None = None, paged_kv_last_page_len_buffer: torch.Tensor | None = None)
Wrapper class for decode attention with a paged kv-cache (first proposed in vLLM) for a batch of requests.
See our tutorial for the page table layout.
Examples
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 8
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> kv_page_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= kv_last_page_len <= page_size
>>> kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> kv_cache_at_layer = [
...     torch.randn(
...         max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> # create auxiliary data structures for batch decode attention
>>> decode_wrapper.plan(
...     kv_page_indptr,
...     kv_page_indices,
...     kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     pos_encoding_mode="NONE",
...     data_type=torch.float16
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
...     kv_cache = kv_cache_at_layer[i]
...     # compute batch decode attention, reuse auxiliary data structures for all layers
...     o = decode_wrapper.run(q, kv_cache)
...     outputs.append(o)
...
>>> outputs[0].shape
torch.Size([7, 64, 128])
Note
To accelerate computation, FlashInfer's batch decode attention creates some auxiliary data structures. These data structures can be reused across multiple batch decode attention calls (e.g. different Transformer layers), and this wrapper class manages their lifecycle.
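The page table in the example above can be read as follows: request i owns the pages indices[indptr[i]:indptr[i + 1]], and only the first last_page_len[i] slots of its final page are in use. The sketch below is plain Python, not part of FlashInfer; the helper name kv_lengths is hypothetical, and it assumes every page except the last is full.

```python
def kv_lengths(indptr, last_page_len, page_size):
    """Return the number of cached KV tokens for each request.

    Assumes every page of a request except the last one is full.
    """
    lengths = []
    for i, last in enumerate(last_page_len):
        num_pages = indptr[i + 1] - indptr[i]
        if num_pages == 0:
            lengths.append(0)  # request with no cached tokens
            continue
        if not 1 <= last <= page_size:
            raise ValueError("last_page_len must satisfy 1 <= len <= page_size")
        lengths.append((num_pages - 1) * page_size + last)
    return lengths


# Same values as the wrapper example above, as plain Python lists.
kv_page_indptr = [0, 17, 29, 44, 48, 66, 100, 128]
kv_last_page_len = [1, 7, 14, 4, 3, 1, 16]
seq_lens = kv_lengths(kv_page_indptr, kv_last_page_len, page_size=16)
print(seq_lens)  # [257, 183, 238, 52, 275, 529, 448]
```

The first request, for example, spans 17 pages, of which 16 are full and the last holds one token, giving 16 * 16 + 1 = 257 tokens.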
- __init__(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, use_tensor_cores: bool = False, paged_kv_indptr_buffer: torch.Tensor | None = None, paged_kv_indices_buffer: torch.Tensor | None = None, paged_kv_last_page_len_buffer: torch.Tensor | None = None) -> None
Constructor of BatchDecodeWithPagedKVCacheWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user-reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB. The workspace buffer must be on the same device as the input tensors.
kv_layout (str) – The layout of the input k/v tensors, either NHD or HND.
use_cuda_graph (bool) – Whether to enable CUDAGraph for batch decode attention. If enabled, the auxiliary data structures are stored in the provided buffers, and batch_size cannot change during the lifecycle of this wrapper.
use_tensor_cores (bool) – Whether to use tensor cores for the computation. This is faster for large group sizes in grouped query attention. Defaults to False.
paged_kv_indptr_buffer (Optional[torch.Tensor]) – The user-reserved buffer on GPU to store the indptr of the paged kv cache. The size of the buffer should be [batch_size + 1]. Only needed when use_cuda_graph is True.
paged_kv_indices_buffer (Optional[torch.Tensor]) – The user-reserved buffer on GPU to store the page indices of the paged kv cache. It should be large enough to store the maximum number of page indices (max_num_pages) during the lifecycle of this wrapper. Only needed when use_cuda_graph is True.
paged_kv_last_page_len_buffer (Optional[torch.Tensor]) – The user-reserved buffer on GPU to store the number of entries in the last page. The size of the buffer should be [batch_size]. Only needed when use_cuda_graph is True.
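The buffer sizes required when use_cuda_graph is True follow directly from the parameter descriptions above. The sketch below only computes element counts in plain Python; the helper name cuda_graph_buffer_sizes is hypothetical, and the int32 dtype mentioned in the comment is an assumption taken from the index tensors in the class example rather than a documented requirement.

```python
def cuda_graph_buffer_sizes(batch_size, max_num_pages):
    """Element counts for the three user-reserved page-table buffers."""
    return {
        "paged_kv_indptr_buffer": batch_size + 1,      # one offset per request, plus the end offset
        "paged_kv_indices_buffer": max_num_pages,      # worst case: every page is referenced
        "paged_kv_last_page_len_buffer": batch_size,   # one entry per request
    }


sizes = cuda_graph_buffer_sizes(batch_size=7, max_num_pages=128)
for name, numel in sizes.items():
    # Assumption: each buffer would be allocated on the GPU as a 1-D int32
    # tensor of this many elements, e.g. torch.empty(numel, dtype=torch.int32).
    print(name, numel)
```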
- plan(indptr: torch.Tensor, indices: torch.Tensor, last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, pos_encoding_mode: str = 'NONE', window_left: int = -1, logits_soft_cap: float | None = None, q_data_type: str | torch.dtype | None = 'float16', kv_data_type: str | torch.dtype | None = None, data_type: str | torch.dtype | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, non_blocking: bool = False) -> None
Plan batch decode for the given problem specification.
- Parameters:
indptr (torch.Tensor) – The indptr of the paged kv cache, shape: [batch_size + 1]
indices (torch.Tensor) – The page indices of the paged kv cache, shape: [indptr[-1]]
last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv cache, shape: [batch_size]
num_qo_heads (int) – The number of query/output heads
num_kv_heads (int) – The number of key/value heads
head_dim (int) – The dimension of the heads
page_size (int) – The page size of the paged kv cache
pos_encoding_mode (str) – The position encoding applied inside attention kernels, one of NONE / ROPE_LLAMA (LLAMA style rotary embedding) / ALIBI. Defaults to NONE.
window_left (int) – The left (inclusive) window size for the attention window. When set to -1, the window size is set to the full length of the sequence. Defaults to -1.
logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.). If not provided, it is set to 0. If greater than 0, the logits are capped according to the formula \(\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})\), where \(x\) is the input logits.
q_data_type (Optional[Union[str, torch.dtype]]) – The data type of the query tensor. Defaults to torch.float16.
kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, it is set to q_data_type. Defaults to None.
data_type (Optional[Union[str, torch.dtype]]) – The data type of both the query and key/value tensors. Defaults to None. data_type is deprecated; use q_data_type and kv_data_type instead.
non_blocking (bool) – Whether to copy the input tensors to the device asynchronously. Defaults to False. If True, the user should synchronize before calling run() or CUDA graph replay.
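The soft-capping formula given for logits_soft_cap can be checked numerically. The following is a scalar sketch in plain Python of that formula only, not of the kernel itself; the function name soft_cap is hypothetical, and treating a value of 0 as "no capping" follows the parameter description above.

```python
import math


def soft_cap(x, logits_soft_cap):
    """Apply logits_soft_cap * tanh(x / logits_soft_cap) to one logit."""
    if logits_soft_cap <= 0:
        return x  # capping disabled
    return logits_soft_cap * math.tanh(x / logits_soft_cap)


cap = 30.0  # illustrative value, not a library default
capped = [soft_cap(x, cap) for x in (-100.0, -1.0, 0.0, 1.0, 100.0)]
print(capped)
```

Small logits pass through almost unchanged, while large ones are squashed smoothly into the open interval (-cap, cap).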
Note
The plan() method should be called before any run() or run_return_lse() calls. Auxiliary data structures are created during this call and cached for multiple run calls.
num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function uses grouped query attention.
The plan() method cannot be used in CUDA Graph or in torch.compile.
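The divisibility requirement on num_qo_heads can be validated before calling plan(). The sketch below is plain Python with hypothetical helper names; the mapping of query head h to KV head h // group_size is the common grouped query attention convention and is assumed here, not stated by this page.

```python
def gqa_group_size(num_qo_heads, num_kv_heads):
    """Number of query heads that share one KV head."""
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    return num_qo_heads // num_kv_heads


def kv_head_for(qo_head, group_size):
    """Assumed convention: consecutive query heads share a KV head."""
    return qo_head // group_size


group_size = gqa_group_size(num_qo_heads=64, num_kv_heads=8)
print(group_size)                   # 8
print(kv_head_for(0, group_size))   # 0
print(kv_head_for(63, group_size))  # 7
```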
- reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor) -> None
Reset the workspace buffer.
- Parameters:
float_workspace_buffer (torch.Tensor) – The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors.
int_workspace_buffer (torch.Tensor) – The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors.
- run(q: torch.Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, return_lse: Literal[False] = False) -> torch.Tensor
- run(q: torch.Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, return_lse: Literal[True] = True) -> Tuple[torch.Tensor, torch.Tensor]
Compute batch decode attention between query and paged kv cache.
- Parameters:
q (torch.Tensor) – The query tensor, shape: [batch_size, num_qo_heads, head_dim]
paged_kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – The paged KV-Cache stored as a tuple of tensors or a single tensor:
  - a tuple (k_cache, v_cache) of 4-D tensors, each with shape [max_num_pages, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, num_kv_heads, page_size, head_dim] if kv_layout is HND.
  - a single 5-D tensor with shape [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND, where paged_kv_cache[:, 0] is the key-cache and paged_kv_cache[:, 1] is the value-cache.
q_scale (Optional[float]) – The calibration scale of query for fp8 input. If not provided, it is set to 1.0.
k_scale (Optional[float]) – The calibration scale of key for fp8 input. If not provided, it is set to 1.0.
v_scale (Optional[float]) – The calibration scale of value for fp8 input. If not provided, it is set to 1.0.
return_lse (bool) – Whether to return the logsumexp of attention scores. Defaults to False.
- Returns:
If return_lse is False, the attention output, shape: [batch_size, num_qo_heads, head_dim]. If return_lse is True, a tuple of two tensors:
  - attention output, shape: [batch_size, num_qo_heads, head_dim]
  - logsumexp of attention scores, shape: [batch_size, num_qo_heads]
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
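To make the shapes above concrete, here is a slow reference sketch in plain Python of what decode attention computes for one request and one head over an NHD paged cache in the tuple form (k_cache, v_cache). It is not the FlashInfer kernel. The helper names are hypothetical, the default scale 1/sqrt(head_dim) is the usual softmax scaling and is assumed here, and the logsumexp is computed with the natural logarithm, which may differ from the base the library uses.

```python
import math


def gather_kv(cache, indices, indptr, last_page_len, request, kv_head):
    """Collect one request's vectors for one KV head from an NHD paged cache.

    cache is indexed as cache[page][slot][kv_head] -> list of head_dim floats.
    """
    pages = indices[indptr[request]:indptr[request + 1]]
    rows = []
    for n, page in enumerate(pages):
        is_last = n == len(pages) - 1
        used = last_page_len[request] if is_last else len(cache[page])
        rows.extend(cache[page][slot][kv_head] for slot in range(used))
    return rows


def decode_one_head(q, keys, values, sm_scale=None):
    """Softmax attention of a single query vector over gathered keys/values."""
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(len(q))  # assumed default scaling
    logits = [sm_scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(x - m) for x in logits]
    denom = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom
           for d in range(len(q))]
    lse = m + math.log(denom)  # natural-log logsumexp of the scaled logits
    return out, lse


# Toy cache: 3 pages, page_size=2, num_kv_heads=1, head_dim=2.
k_cache = [[[[1.0, 0.0]], [[0.0, 1.0]]],
           [[[1.0, 1.0]], [[9.0, 9.0]]],   # second slot is unused padding
           [[[0.0, 0.0]], [[0.0, 0.0]]]]
v_cache = [[[[1.0, 0.0]], [[0.0, 1.0]]],
           [[[0.5, 0.5]], [[7.0, 7.0]]],
           [[[0.0, 0.0]], [[0.0, 0.0]]]]
indptr, indices, last_page_len = [0, 2], [0, 1], [1]

keys = gather_kv(k_cache, indices, indptr, last_page_len, request=0, kv_head=0)
values = gather_kv(v_cache, indices, indptr, last_page_len, request=0, kv_head=0)
out, lse = decode_one_head([0.0, 0.0], keys, values)
print(len(keys), out, lse)
```

With a zero query every key gets the same weight, so the output is the mean of the three cached value vectors and the padding slot holding [7.0, 7.0] is never read.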
- class flashinfer.decode.CUDAGraphBatchDecodeWithPagedKVCacheWrapper(workspace_buffer: torch.Tensor, indptr_buffer: torch.Tensor, indices_buffer: torch.Tensor, last_page_len_buffer: torch.Tensor, kv_layout: str = 'NHD', use_tensor_cores: bool = False)
CUDAGraph-compatible wrapper class for decode attention with a paged kv-cache (first proposed in vLLM) for a batch of requests.
Note that this wrapper may not be as efficient as BatchDecodeWithPagedKVCacheWrapper, because it does not dispatch to different kernels for different batch sizes, sequence lengths, etc., in order to accommodate the CUDAGraph requirement.
See our tutorial for the page table layout.
Note
The plan() method cannot be captured by CUDAGraph.
- __init__(workspace_buffer: torch.Tensor, indptr_buffer: torch.Tensor, indices_buffer: torch.Tensor, last_page_len_buffer: torch.Tensor, kv_layout: str = 'NHD', use_tensor_cores: bool = False) -> None
Constructor of CUDAGraphBatchDecodeWithPagedKVCacheWrapper.
- Parameters:
workspace_buffer (torch.Tensor) – The user-reserved workspace buffer on GPU used to store auxiliary data structures. The recommended size is 128MB. The workspace buffer must be on the same device as the input tensors.
indptr_buffer (torch.Tensor) – The user-reserved buffer on GPU to store the indptr of the paged kv cache. It should be large enough to store the indptr of the maximum batch size ([max_batch_size + 1]) during the lifecycle of this wrapper.
indices_buffer (torch.Tensor) – The user-reserved buffer on GPU to store the page indices of the paged kv cache. It should be large enough to store the maximum number of page indices (max_num_pages) during the lifecycle of this wrapper.
last_page_len_buffer (torch.Tensor) – The user-reserved buffer on GPU to store the number of entries in the last page. It should be large enough to store the maximum batch size ([max_batch_size]) during the lifecycle of this wrapper.
use_tensor_cores (bool) – Whether to use tensor cores for the computation. This is faster for large group sizes in grouped query attention. Defaults to False.
kv_layout (str) – The layout of the input k/v tensors, either NHD or HND.