flashinfer.decode#
Single Request Decoding#
single_decode_with_kv_cache() – Decode attention with KV Cache for single request, return attention output.
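A minimal usage sketch of single request decoding (the head counts and kv_len below are illustrative, and the k/v shapes assume the default NHD layout):
>>> import torch
>>> import flashinfer
>>> num_qo_heads, num_kv_heads, head_dim, kv_len = 32, 8, 128, 4096
>>> q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> # single-request decode attention; output shape: [num_qo_heads, head_dim]
>>> o = flashinfer.single_decode_with_kv_cache(q, k, v)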
Batch Decoding#
- class flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, use_tensor_cores: bool = False, paged_kv_indptr_buffer: torch.Tensor | None = None, paged_kv_indices_buffer: torch.Tensor | None = None, paged_kv_last_page_len_buffer: torch.Tensor | None = None)#
Wrapper class for decode attention with paged kv-cache (first proposed in vLLM) for batch of requests.
Check our tutorial for page table layout.
Examples
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 8
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> kv_page_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= kv_last_page_len <= page_size
>>> kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> kv_data_at_layer = [
...     torch.randn(
...         max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> # create auxiliary data structures for batch decode attention
>>> decode_wrapper.begin_forward(
...     kv_page_indptr,
...     kv_page_indices,
...     kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     pos_encoding_mode="NONE",
...     data_type=torch.float16
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
...     kv_data = kv_data_at_layer[i]
...     # compute batch decode attention, reuse auxiliary data structures for all layers
...     o = decode_wrapper.forward(q, kv_data)
...     outputs.append(o)
...
>>> # clear auxiliary data structures
>>> decode_wrapper.end_forward()
>>> outputs[0].shape
torch.Size([7, 64, 128])
Note
To accelerate computation, FlashInfer’s batch decode attention creates some auxiliary data structures, these data structures can be reused across multiple batch decode attention calls (e.g. different Transformer layers). This wrapper class manages the lifecycle of these data structures.
- __init__(workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, use_tensor_cores: bool = False, paged_kv_indptr_buffer: torch.Tensor | None = None, paged_kv_indices_buffer: torch.Tensor | None = None, paged_kv_last_page_len_buffer: torch.Tensor | None = None)#
Constructor of BatchDecodeWithPagedKVCacheWrapper.
- Parameters:
workspace_buffer (torch.Tensor) – The user reserved workspace buffer used to store auxiliary data structures, recommended size is 128MB; the device of the workspace buffer should be the same as the device of the input tensors.
kv_layout (str) – The layout of the input k/v tensors, could be either NHD or HND.
use_cuda_graph (bool) – Whether to enable CUDAGraph for batch decode attention; if enabled, the auxiliary data structures will be stored in the provided buffers. The batch_size cannot change during the lifecycle of this wrapper when CUDAGraph is enabled (see the construction sketch after this parameter list).
use_tensor_cores (bool) – Whether to use tensor cores for the computation. Will be faster for large group size in grouped query attention. Defaults to False.
paged_kv_indptr_buffer (Optional[torch.Tensor]) – The user reserved buffer on GPU to store the indptr of the paged kv cache; the size of the buffer should be [batch_size + 1]. Only needed when use_cuda_graph is True.
paged_kv_indices_buffer (Optional[torch.Tensor]) – The user reserved buffer on GPU to store the page indices of the paged kv cache; should be large enough to store the maximum number of page indices (max_num_pages) during the lifecycle of this wrapper. Only needed when use_cuda_graph is True.
paged_kv_last_page_len_buffer (Optional[torch.Tensor]) – The user reserved buffer on GPU to store the number of entries in the last page; the size of the buffer should be [batch_size]. Only needed when use_cuda_graph is True.
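A minimal construction sketch for the CUDAGraph path, with user-provided buffers (the capacity constants max_batch_size and max_num_pages are assumptions you must size for your own workload):
>>> import torch
>>> import flashinfer
>>> max_batch_size, max_num_pages = 16, 1024
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> indptr_buf = torch.empty(max_batch_size + 1, dtype=torch.int32, device="cuda:0")
>>> indices_buf = torch.empty(max_num_pages, dtype=torch.int32, device="cuda:0")
>>> last_page_len_buf = torch.empty(max_batch_size, dtype=torch.int32, device="cuda:0")
>>> decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
...     workspace_buffer,
...     "NHD",
...     use_cuda_graph=True,
...     paged_kv_indptr_buffer=indptr_buf,
...     paged_kv_indices_buffer=indices_buf,
...     paged_kv_last_page_len_buffer=last_page_len_buf,
... )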
- begin_forward(indptr: torch.Tensor, indices: torch.Tensor, last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, pos_encoding_mode: str = 'NONE', logits_soft_cap: float | None = None, data_type: str | torch.dtype = 'float16', q_data_type: str | torch.dtype | None = None)#
Create auxiliary data structures for batch decode for multiple forward calls within the same decode step.
- Parameters:
indptr (torch.Tensor) – The indptr of the paged kv cache, shape: [batch_size + 1]
indices (torch.Tensor) – The page indices of the paged kv cache, shape: [indptr[-1]]
last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv cache, shape: [batch_size]
num_qo_heads (int) – The number of query/output heads
num_kv_heads (int) – The number of key/value heads
head_dim (int) – The dimension of the heads
page_size (int) – The page size of the paged kv cache
pos_encoding_mode (str) – The position encoding applied inside attention kernels, could be NONE / ROPE_LLAMA (LLAMA style rotary embedding) / ALIBI. Defaults to NONE.
logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.); if not provided, will be set to 0. If greater than 0, the logits will be capped according to the formula: \(\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})\), where \(x\) is the input logits.
data_type (Union[str, torch.dtype]) – The data type of the paged kv cache. Defaults to float16.
q_data_type (Optional[Union[str, torch.dtype]]) – The data type of the query tensor. If None, will be set to data_type. Defaults to None.
Note
The begin_forward() method should be called before any forward() or forward_return_lse() calls; auxiliary data structures will be created during this call and cached for multiple forward calls.
The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.
- end_forward()#
Clear auxiliary data structures created by begin_forward().
- forward(q: torch.Tensor, paged_kv_data: torch.Tensor, pos_encoding_mode: str = 'NONE', q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None)#
Compute batch decode attention between query and paged kv cache.
- Parameters:
q (torch.Tensor) – The query tensor, shape: [batch_size, num_qo_heads, head_dim]
paged_kv_data (torch.Tensor) – A 5-D tensor of the reserved paged kv-cache data, shape: [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, or [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND.
pos_encoding_mode (str) – The position encoding applied inside attention kernels, could be NONE / ROPE_LLAMA (LLAMA style rotary embedding) / ALIBI. Defaults to NONE.
q_scale (Optional[float]) – The calibration scale of query for fp8 input; if not provided, will be set to 1.0.
k_scale (Optional[float]) – The calibration scale of key for fp8 input; if not provided, will be set to 1.0.
v_scale (Optional[float]) – The calibration scale of value for fp8 input; if not provided, will be set to 1.0.
logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.); if not provided, will be set to 0. If greater than 0, the logits will be capped according to the formula: \(\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})\), where \(x\) is the input logits.
sm_scale (Optional[float]) – The scale of softmax; if not provided, will be set to 1 / sqrt(head_dim).
rope_scale (Optional[float]) – The scale used in RoPE interpolation; if not provided, will be set to 1.0.
rope_theta (Optional[float]) – The theta used in RoPE; if not provided, will be set to 1e4.
- Returns:
The attention output, shape: [batch_size, num_qo_heads, head_dim].
- Return type:
torch.Tensor
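For example, continuing from the class-level example above (after begin_forward() has been called for the current batch), the optional per-call knobs can be supplied like this; the capping and scale values are illustrative, not recommendations:
>>> q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> o = decode_wrapper.forward(
...     q,
...     kv_data,
...     pos_encoding_mode="NONE",
...     logits_soft_cap=30.0,              # illustrative soft cap; omit or 0 to disable capping
...     sm_scale=1.0 / (head_dim ** 0.5),  # same as the default softmax scale
... )
>>> o.shape
torch.Size([7, 64, 128])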
- forward_return_lse(q: torch.Tensor, paged_kv_data: torch.Tensor, pos_encoding_mode: str = 'NONE', q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None)#
Compute batch decode attention with paged kv cache, return attention output and logsumexp of attention scores.
- Parameters:
q (torch.Tensor) – The query tensor, shape: [batch_size, num_qo_heads, head_dim]
paged_kv_data (torch.Tensor) – A 5-D tensor of the reserved paged kv-cache data, shape: [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, or [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND.
pos_encoding_mode (str) – The position encoding applied inside attention kernels, could be NONE / ROPE_LLAMA (LLAMA style rotary embedding) / ALIBI. Defaults to NONE.
q_scale (Optional[float]) – The calibration scale of query for fp8 input; if not provided, will be set to 1.0.
k_scale (Optional[float]) – The calibration scale of key for fp8 input; if not provided, will be set to 1.0.
v_scale (Optional[float]) – The calibration scale of value for fp8 input; if not provided, will be set to 1.0.
logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.); if not provided, will be set to 0. If greater than 0, the logits will be capped according to the formula: \(\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})\), where \(x\) is the input logits.
sm_scale (Optional[float]) – The scale of softmax; if not provided, will be set to 1 / sqrt(head_dim).
rope_scale (Optional[float]) – The scale used in RoPE interpolation; if not provided, will be set to 1.0.
rope_theta (Optional[float]) – The theta used in RoPE; if not provided, will be set to 1e4.
- Returns:
V (torch.Tensor) – The attention output, shape: [batch_size, num_qo_heads, head_dim].
S (torch.Tensor) – The logsumexp of attention scores, shape: [batch_size, num_qo_heads].
Notes
Please refer to the tutorial for a detailed explanation of the log-sum-exp function and attention states.
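A minimal sketch, reusing q and kv_data from the class-level example above. The second return value gives the per-(request, head) logsumexp, which is what partial attention states are merged with; in the generic log-sum-exp merge, two partial results \((o_1, s_1)\) and \((o_2, s_2)\) combine as \(o = (e^{s_1} o_1 + e^{s_2} o_2) / (e^{s_1} + e^{s_2})\) and \(s = \log(e^{s_1} + e^{s_2})\) (see the tutorial for the exact convention used by FlashInfer):
>>> o, lse = decode_wrapper.forward_return_lse(q, kv_data)
>>> o.shape
torch.Size([7, 64, 128])
>>> lse.shape
torch.Size([7, 64])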
- reset_workspace_buffer(new_workspace_buffer: torch.Tensor)#
Reset the workspace buffer.
- Parameters:
new_workspace_buffer (torch.Tensor) – The new workspace buffer; the device of the new workspace buffer should be the same as the device of the input tensors.
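For example, a sketch of swapping in a larger workspace buffer (the 256MB size is an arbitrary illustration):
>>> new_workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> decode_wrapper.reset_workspace_buffer(new_workspace_buffer)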
- class flashinfer.decode.CUDAGraphBatchDecodeWithPagedKVCacheWrapper(workspace_buffer: torch.Tensor, indptr_buffer: torch.Tensor, indices_buffer: torch.Tensor, last_page_len_buffer: torch.Tensor, kv_layout: str = 'NHD', use_tensor_cores: bool = False)#
CUDAGraph-compatible Wrapper class for decode attention with paged kv-cache (first proposed in vLLM) for batch of requests.
Note that this wrapper may not be as efficient as BatchDecodeWithPagedKVCacheWrapper because we won't dispatch to different kernels for different batch sizes/sequence lengths/etc. to accommodate the CUDAGraph requirement.
Check our tutorial for page table layout.
Note
The begin_forward() method cannot be captured by CUDAGraph.
See also
BatchDecodeWithPagedKVCacheWrapper
- __init__(workspace_buffer: torch.Tensor, indptr_buffer: torch.Tensor, indices_buffer: torch.Tensor, last_page_len_buffer: torch.Tensor, kv_layout: str = 'NHD', use_tensor_cores: bool = False)#
Constructor of CUDAGraphBatchDecodeWithPagedKVCacheWrapper; a construction sketch follows the parameter list below.
- Parameters:
workspace_buffer (torch.Tensor) – The user reserved workspace buffer on GPU used to store auxiliary data structures, recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors.
indptr_buffer (torch.Tensor) – The user reserved buffer on GPU to store the indptr of the paged kv cache, should be large enough to store the indptr of the maximum batch size ([max_batch_size + 1]) during the lifecycle of this wrapper.
indices_buffer (torch.Tensor) – The user reserved buffer on GPU to store the page indices of the paged kv cache, should be large enough to store the maximum number of page indices (max_num_pages) during the lifecycle of this wrapper.
last_page_len_buffer (torch.Tensor) – The user reserved buffer on GPU to store the number of entries in the last page, should be large enough to store the maximum batch size ([max_batch_size]) during the lifecycle of this wrapper.
use_tensor_cores (bool) – Whether to use tensor cores for the computation. Will be faster for large group size in grouped query attention. Defaults to False.
kv_layout (str) – The layout of the input k/v tensors, could be either NHD or HND.
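A minimal construction sketch under assumed capacity limits (max_batch_size and max_num_pages are placeholders you must size for your own serving configuration):
>>> import torch
>>> import flashinfer
>>> max_batch_size, max_num_pages = 16, 1024
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> indptr_buffer = torch.empty(max_batch_size + 1, dtype=torch.int32, device="cuda:0")
>>> indices_buffer = torch.empty(max_num_pages, dtype=torch.int32, device="cuda:0")
>>> last_page_len_buffer = torch.empty(max_batch_size, dtype=torch.int32, device="cuda:0")
>>> cg_decode_wrapper = flashinfer.decode.CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
...     workspace_buffer, indptr_buffer, indices_buffer, last_page_len_buffer, kv_layout="NHD"
... )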