flashinfer.decode

Single Request Decoding

single_decode_with_kv_cache(q, k, v[, ...])

Decode attention with KV cache for a single request; returns the attention output.
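For reference, the per-head computation behind decode attention fits in a few lines of plain Python. This is a minimal sketch, not FlashInfer's kernel. It assumes a single query head, no positional encoding and no soft capping, and it uses the default sm_scale of 1 / sqrt(head_dim) documented for forward(). The function name is hypothetical.

```python
import math


def decode_attention_one_head(q, keys, values, sm_scale=None):
    """Reference decode attention for a single query head (illustrative only).

    q:      the new token's query, a list of head_dim floats
    keys:   kv_len cached key vectors, each a list of head_dim floats
    values: kv_len cached value vectors, each a list of floats
    """
    head_dim = len(q)
    if sm_scale is None:
        # Default softmax scale, matching the documented 1 / sqrt(head_dim).
        sm_scale = 1.0 / math.sqrt(head_dim)
    # One scaled logit per cached token.
    logits = [sm_scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    # Numerically stable softmax: subtract the maximum before exponentiating.
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    total = sum(weights)
    # The output is the softmax-weighted average of the value vectors.
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


# A zero query attends uniformly, so the output is the mean of the values.
print(decode_attention_one_head([0.0, 0.0], [[1.0, 0.0], [0.0, 1.0]],
                                [[1.0, 2.0], [3.0, 4.0]]))  # [2.0, 3.0]
```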

Batch Decoding

class flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, use_tensor_cores: bool = False, paged_kv_indptr_buffer: torch.Tensor | None = None, paged_kv_indices_buffer: torch.Tensor | None = None, paged_kv_last_page_len_buffer: torch.Tensor | None = None)

Wrapper class for decode attention with paged kv-cache (first proposed in vLLM) over a batch of requests.

See our tutorial for the page table layout.

Examples

>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 8
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> kv_page_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= kv_last_page_len <= page_size
>>> kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> kv_data_at_layer = [
...     torch.randn(
...         max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> # create auxiliary data structures for batch decode attention
>>> decode_wrapper.begin_forward(
...     kv_page_indptr,
...     kv_page_indices,
...     kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     pos_encoding_mode="NONE",
...     data_type=torch.float16
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
...     kv_data = kv_data_at_layer[i]
...     # compute batch decode attention, reuse auxiliary data structures for all layers
...     o = decode_wrapper.forward(q, kv_data)
...     outputs.append(o)
...
>>> # clear auxiliary data structures
>>> decode_wrapper.end_forward()
>>> outputs[0].shape
torch.Size([7, 64, 128])
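The page table in the example above can be decoded by hand. The sketch below uses plain Python lists in place of the int32 tensors and assumes the convention implied by `1 <= kv_last_page_len <= page_size`: every page of a request is full except its last one. Under that assumption, request i owns the pages `indices[indptr[i]:indptr[i+1]]` and holds `(num_pages - 1) * page_size + last_page_len[i]` tokens. The helper names are hypothetical.

```python
def kv_lengths(indptr, last_page_len, page_size):
    """Number of cached tokens per request implied by a paged KV page table."""
    lengths = []
    for i in range(len(indptr) - 1):
        num_pages = indptr[i + 1] - indptr[i]
        # Every page of a request is full except the last one,
        # which holds last_page_len[i] entries (1 <= last_page_len[i] <= page_size).
        assert num_pages >= 1 and 1 <= last_page_len[i] <= page_size
        lengths.append((num_pages - 1) * page_size + last_page_len[i])
    return lengths


def pages_of(indptr, indices, i):
    """Page ids owned by request i, in order."""
    return indices[indptr[i]:indptr[i + 1]]


# Values from the example above (page_size = 16, 128 pages).
indptr = [0, 17, 29, 44, 48, 66, 100, 128]
last_page_len = [1, 7, 14, 4, 3, 1, 16]
indices = list(range(128))

assert len(indices) == indptr[-1]  # one page id per indptr slot
print(kv_lengths(indptr, last_page_len, 16))  # [257, 183, 238, 52, 275, 529, 448]
print(pages_of(indptr, indices, 3))           # [44, 45, 46, 47]
```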

Note

To accelerate computation, FlashInfer’s batch decode attention creates auxiliary data structures that can be reused across multiple batch decode attention calls (e.g. across different Transformer layers). This wrapper class manages the lifecycle of these data structures.

__init__(workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, use_tensor_cores: bool = False, paged_kv_indptr_buffer: torch.Tensor | None = None, paged_kv_indices_buffer: torch.Tensor | None = None, paged_kv_last_page_len_buffer: torch.Tensor | None = None)

Constructor of BatchDecodeWithPagedKVCacheWrapper.

Parameters:
  • workspace_buffer (torch.Tensor) – The user-reserved workspace buffer used to store auxiliary data structures. The recommended size is 128MB; the buffer should be on the same device as the input tensors.

  • kv_layout (str) – The layout of the input k/v tensors, either NHD or HND.

  • use_cuda_graph (bool) – Whether to enable CUDAGraph for batch decode attention. If enabled, the auxiliary data structures are stored in the provided buffers, and the batch_size cannot change during the lifecycle of this wrapper.

  • use_tensor_cores (bool) – Whether to use tensor cores for the computation. This can be faster for large group sizes in grouped query attention. Defaults to False.

  • paged_kv_indptr_buffer (Optional[torch.Tensor]) – The user-reserved buffer on GPU to store the indptr of the paged kv cache; its size should be [batch_size + 1]. Only needed when use_cuda_graph is True.

  • paged_kv_indices_buffer (Optional[torch.Tensor]) – The user-reserved buffer on GPU to store the page indices of the paged kv cache. It should be large enough to store the maximum number of page indices (max_num_pages) during the lifecycle of this wrapper. Only needed when use_cuda_graph is True.

  • paged_kv_last_page_len_buffer (Optional[torch.Tensor]) – The user-reserved buffer on GPU to store the number of entries in the last page of each request; its size should be [batch_size]. Only needed when use_cuda_graph is True.

begin_forward(indptr: torch.Tensor, indices: torch.Tensor, last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, pos_encoding_mode: str = 'NONE', logits_soft_cap: float | None = None, data_type: str | torch.dtype = 'float16', q_data_type: str | torch.dtype | None = None)

Create auxiliary data structures for batch decode, to be reused by multiple forward calls within the same decode step.

Parameters:
  • indptr (torch.Tensor) – The indptr of the paged kv cache, shape: [batch_size + 1]

  • indices (torch.Tensor) – The page indices of the paged kv cache, shape: [indptr[-1]]

  • last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv cache, shape: [batch_size]

  • num_qo_heads (int) – The number of query/output heads

  • num_kv_heads (int) – The number of key/value heads

  • head_dim (int) – The dimension of the heads

  • page_size (int) – The page size of the paged kv cache

  • pos_encoding_mode (str) – The position encoding applied inside attention kernels, one of NONE, ROPE_LLAMA (LLaMA-style rotary embedding) or ALIBI. Defaults to NONE.

  • logits_soft_cap (Optional[float]) – The attention logits soft-capping value (used in Gemini, Grok, Gemma-2, etc.). If not provided, it is set to 0. If greater than 0, the logits are capped according to the formula \(\texttt{logits\_soft\_cap} \times \mathrm{tanh}(x / \texttt{logits\_soft\_cap})\), where \(x\) is the input logit.

  • data_type (Union[str, torch.dtype]) – The data type of the paged kv cache. Defaults to float16.

  • q_data_type (Optional[Union[str, torch.dtype]]) – The data type of the query tensor. If None, will be set to data_type. Defaults to None.
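The logits_soft_cap formula is easy to check numerically. Here is a minimal sketch applying it to one scalar logit in plain Python. It deliberately makes no assumption about where the kernels apply the cap relative to sm_scale, and the function name is hypothetical.

```python
import math


def soft_cap(logit, logits_soft_cap=None):
    """Tanh soft capping of one attention logit.

    A cap of None or 0 leaves the logit unchanged; a positive cap squashes
    the logit smoothly so its magnitude never exceeds the cap.
    """
    if not logits_soft_cap:
        return logit
    return logits_soft_cap * math.tanh(logit / logits_soft_cap)


print(soft_cap(5.0))        # 5.0 (no capping)
print(soft_cap(1e4, 30.0))  # saturates near 30.0
```

For small logits tanh(x) is close to x, so capping barely changes them. Large logits are bounded by the cap instead of growing without limit.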

Note

The begin_forward() method should be called before any forward() or forward_return_lse() calls; the auxiliary data structures are created during this call and cached for multiple forward calls.

num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads differs from num_kv_heads, grouped query attention is used.
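The head grouping can be sketched as follows. This assumes the common contiguous grouping, where query head h reads KV head h // (num_qo_heads // num_kv_heads). That is an assumption about the layout, not something this reference states, and the function name is hypothetical.

```python
def kv_head_for(qo_head, num_qo_heads, num_kv_heads):
    """Index of the KV head that query/output head `qo_head` reads from.

    Assumes contiguous grouping: each run of group_size consecutive query
    heads shares one KV head (group_size == 1 is ordinary multi-head attention).
    """
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    group_size = num_qo_heads // num_kv_heads
    return qo_head // group_size


# The example configuration above: 64 query heads sharing 8 KV heads.
mapping = [kv_head_for(h, 64, 8) for h in range(64)]
print(mapping[:9])   # [0, 0, 0, 0, 0, 0, 0, 0, 1]
print(mapping[-1])   # 7
```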

end_forward()

Clear auxiliary data structures created by begin_forward().

forward(q: torch.Tensor, paged_kv_data: torch.Tensor, pos_encoding_mode: str = 'NONE', q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None)

Compute batch decode attention between query and paged kv cache.

Parameters:
  • q (torch.Tensor) – The query tensor, shape: [batch_size, num_qo_heads, head_dim]

  • paged_kv_data (torch.Tensor) – A 5-D tensor of the reserved paged kv-cache data, shape: [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, or [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND.

  • pos_encoding_mode (str) – The position encoding applied inside attention kernels, one of NONE, ROPE_LLAMA (LLaMA-style rotary embedding) or ALIBI. Defaults to NONE.

  • q_scale (Optional[float]) – The calibration scale of the query for fp8 input. If not provided, it is set to 1.0.

  • k_scale (Optional[float]) – The calibration scale of the key for fp8 input. If not provided, it is set to 1.0.

  • v_scale (Optional[float]) – The calibration scale of the value for fp8 input. If not provided, it is set to 1.0.

  • logits_soft_cap (Optional[float]) – The attention logits soft-capping value (used in Gemini, Grok, Gemma-2, etc.). If not provided, it is set to 0. If greater than 0, the logits are capped according to the formula \(\texttt{logits\_soft\_cap} \times \mathrm{tanh}(x / \texttt{logits\_soft\_cap})\), where \(x\) is the input logit.

  • sm_scale (Optional[float]) – The softmax scale. If not provided, it is set to 1 / sqrt(head_dim).

  • rope_scale (Optional[float]) – The scale used in RoPE interpolation. If not provided, it is set to 1.0.

  • rope_theta (Optional[float]) – The theta used in RoPE. If not provided, it is set to 1e4.
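ROPE_LLAMA, rope_theta and rope_scale can be illustrated with a plain-Python rotary embedding. This sketch rests on two assumptions that this reference does not state. First, it uses the "rotate half" pairing of the Hugging Face LLaMA implementation. Second, it treats rope_scale as linear position interpolation, dividing positions by rope_scale. Check FlashInfer's kernels before relying on either detail. The helper names are hypothetical.

```python
import math


def apply_rope_llama(x, pos, rope_theta=1e4, rope_scale=1.0):
    """Rotate a query/key vector x (even length head_dim) for token position pos.

    Assumptions: the non-interleaved "rotate half" pairing (element i with
    element i + head_dim // 2) and linear position interpolation, i.e. the
    position is divided by rope_scale before computing the angles.
    """
    head_dim = len(x)
    half = head_dim // 2
    out = list(x)
    for i in range(half):
        inv_freq = rope_theta ** (-2.0 * i / head_dim)  # per-pair frequency
        angle = (pos / rope_scale) * inv_freq
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[i], x[i + half]
        # 2-D rotation of the pair (a, b) by `angle`.
        out[i] = a * c - b * s
        out[i + half] = b * c + a * s
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q, k = [0.3, -1.2, 0.8, 0.5], [1.0, 0.4, -0.7, 0.2]
# The score depends only on the offset between query and key positions (9 - 4 == 5 - 0).
print(dot(apply_rope_llama(q, 9), apply_rope_llama(k, 4)))
print(dot(apply_rope_llama(q, 5), apply_rope_llama(k, 0)))
```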

Returns:

The attention output, shape: [batch_size, num_qo_heads, head_dim].

Return type:

torch.Tensor
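The paged_kv_data layout can be unpacked in plain Python to show what forward() reads for each request. This minimal sketch uses nested lists in place of the 5-D tensor. It assumes the NHD layout documented above and that only a request's last page is partially filled. The helper name is hypothetical.

```python
def gather_request_kv(paged_kv_data, indptr, indices, last_page_len, i):
    """Reassemble request i's cached keys and values from an NHD paged cache.

    paged_kv_data is indexed [page][0 for K, 1 for V][slot][kv_head][head_dim].
    Returns (keys, values), each indexed [token][kv_head][head_dim].
    """
    page_ids = indices[indptr[i]:indptr[i + 1]]
    page_size = len(paged_kv_data[0][0])
    keys, values = [], []
    for n, page in enumerate(page_ids):
        # Only the request's last page may be partially filled.
        used = last_page_len[i] if n == len(page_ids) - 1 else page_size
        for slot in range(used):
            keys.append(paged_kv_data[page][0][slot])
            values.append(paged_kv_data[page][1][slot])
    return keys, values


# Toy cache: 3 pages, page_size 2, 1 KV head, head_dim 1.
# The key stored at (page p, slot s) is 10 * p + s; the value is its negation.
cache = [[[[[10 * p + s]] for s in range(2)],
          [[[-(10 * p + s)]] for s in range(2)]] for p in range(3)]
keys, values = gather_request_kv(cache, indptr=[0, 2, 3], indices=[2, 0, 1],
                                 last_page_len=[1, 2], i=0)
print(keys)  # [[[20]], [[21]], [[0]]]
```

Pages need not be contiguous or in order. Request 0 above reads page 2 in full and then one slot of page 0, yielding the contiguous KV sequence that a non-paged reference attention would consume.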

forward_return_lse(q: torch.Tensor, paged_kv_data: torch.Tensor, pos_encoding_mode: str = 'NONE', q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None)

Compute batch decode attention with a paged kv cache, returning the attention output and the logsumexp of the attention scores.

Parameters:
  • q (torch.Tensor) – The query tensor, shape: [batch_size, num_qo_heads, head_dim]

  • paged_kv_data (torch.Tensor) – A 5-D tensor of the reserved paged kv-cache data, shape: [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, or [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND.

  • pos_encoding_mode (str) – The position encoding applied inside attention kernels, one of NONE, ROPE_LLAMA (LLaMA-style rotary embedding) or ALIBI. Defaults to NONE.

  • q_scale (Optional[float]) – The calibration scale of the query for fp8 input. If not provided, it is set to 1.0.

  • k_scale (Optional[float]) – The calibration scale of the key for fp8 input. If not provided, it is set to 1.0.

  • v_scale (Optional[float]) – The calibration scale of the value for fp8 input. If not provided, it is set to 1.0.

  • logits_soft_cap (Optional[float]) – The attention logits soft-capping value (used in Gemini, Grok, Gemma-2, etc.). If not provided, it is set to 0. If greater than 0, the logits are capped according to the formula \(\texttt{logits\_soft\_cap} \times \mathrm{tanh}(x / \texttt{logits\_soft\_cap})\), where \(x\) is the input logit.

  • sm_scale (Optional[float]) – The softmax scale. If not provided, it is set to 1 / sqrt(head_dim).

  • rope_scale (Optional[float]) – The scale used in RoPE interpolation. If not provided, it is set to 1.0.

  • rope_theta (Optional[float]) – The theta used in RoPE. If not provided, it is set to 1e4.

Returns:

  • V (torch.Tensor) – The attention output, shape: [batch_size, num_qo_heads, head_dim].

  • S (torch.Tensor) – The logsumexp of attention scores, shape: [batch_size, num_qo_heads].

Notes

Please refer to the tutorial for a detailed explanation of the log-sum-exp function and attention states.
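The V/S pair returned by forward_return_lse() is what makes partial results combinable. The sketch below shows the standard merge of two attention states computed over disjoint key sets, for a single head in plain Python. It assumes S is a natural-log log-sum-exp. If the kernels report S in another base (e.g. base 2), the exponentials must be adjusted to match. The helper names are hypothetical.

```python
import math


def attention_state(logits, values):
    """(output, lse) of softmax attention for one head over a set of keys.

    logits are the already-scaled scores; lse is the natural-log
    log-sum-exp of those logits.
    """
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(len(values[0]))]
    return out, m + math.log(total)


def merge_states(v_a, s_a, v_b, s_b):
    """Merge two attention states computed over disjoint sets of keys."""
    m = max(s_a, s_b)
    s = m + math.log(math.exp(s_a - m) + math.exp(s_b - m))  # stable logaddexp
    # Each partial output is reweighted by its share of the total softmax mass.
    w_a, w_b = math.exp(s_a - s), math.exp(s_b - s)
    return [w_a * x + w_b * y for x, y in zip(v_a, v_b)], s


logits = [0.5, -1.0, 2.0, 0.1]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]
full_v, full_s = attention_state(logits, values)
v_a, s_a = attention_state(logits[:2], values[:2])
v_b, s_b = attention_state(logits[2:], values[2:])
merged_v, merged_s = merge_states(v_a, s_a, v_b, s_b)
# merged_v and merged_s agree with full_v and full_s up to rounding.
```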

reset_workspace_buffer(new_workspace_buffer: torch.Tensor)

Reset the workspace buffer.

Parameters:

new_workspace_buffer (torch.Tensor) – The new workspace buffer, the device of the new workspace buffer should be the same as the device of the input tensors.

class flashinfer.decode.CUDAGraphBatchDecodeWithPagedKVCacheWrapper(workspace_buffer: torch.Tensor, indptr_buffer: torch.Tensor, indices_buffer: torch.Tensor, last_page_len_buffer: torch.Tensor, kv_layout: str = 'NHD', use_tensor_cores: bool = False)

CUDAGraph-compatible wrapper class for decode attention with paged kv-cache (first proposed in vLLM) over a batch of requests.

Note that this wrapper may not be as efficient as BatchDecodeWithPagedKVCacheWrapper, because to accommodate the CUDAGraph requirements we do not dispatch to different kernels for different batch sizes, sequence lengths, etc.

See our tutorial for the page table layout.

Note

The begin_forward() method cannot be captured by CUDAGraph.

__init__(workspace_buffer: torch.Tensor, indptr_buffer: torch.Tensor, indices_buffer: torch.Tensor, last_page_len_buffer: torch.Tensor, kv_layout: str = 'NHD', use_tensor_cores: bool = False)

Constructor of CUDAGraphBatchDecodeWithPagedKVCacheWrapper.

Parameters:
  • workspace_buffer (torch.Tensor) – The user-reserved workspace buffer on GPU used to store auxiliary data structures. The recommended size is 128MB; the buffer should be on the same device as the input tensors.

  • indptr_buffer (torch.Tensor) – The user reserved buffer on GPU to store the indptr of the paged kv cache, should be large enough to store the indptr of maximum batch size ([max_batch_size + 1]) during the lifecycle of this wrapper.

  • indices_buffer (torch.Tensor) – The user reserved buffer on GPU to store the page indices of the paged kv cache, should be large enough to store the maximum number of page indices (max_num_pages) during the lifecycle of this wrapper.

  • last_page_len_buffer (torch.Tensor) – The user reserved buffer on GPU to store the number of entries in the last page, should be large enough to store the maximum batch size ([max_batch_size]) during the lifecycle of this wrapper.

  • kv_layout (str) – The layout of the input k/v tensors, either NHD or HND.

  • use_tensor_cores (bool) – Whether to use tensor cores for the computation. This can be faster for large group sizes in grouped query attention. Defaults to False.
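These buffer sizing rules can be mimicked with plain Python lists standing in for the GPU tensors. This is a sketch under stated assumptions, not the wrapper's internals: the helper fill_graph_buffers is hypothetical, and real code would make in-place copies into slices of preallocated int32 tensors. It shows why the buffers are sized once for max_batch_size and max_num_pages and are then only overwritten, never reallocated.

```python
def fill_graph_buffers(indptr_buf, indices_buf, last_page_len_buf,
                       indptr, indices, last_page_len):
    """Copy one decode step's page table into preallocated fixed-size buffers.

    The buffers are updated in place and never resized, mirroring the rule
    that a captured CUDA graph must keep reading from the same memory.
    Returns the batch size of this step.
    """
    batch_size = len(indptr) - 1
    if batch_size + 1 > len(indptr_buf) or batch_size > len(last_page_len_buf):
        raise ValueError("batch size exceeds the max_batch_size the buffers were sized for")
    if len(indices) > len(indices_buf):
        raise ValueError("number of pages exceeds the indices buffer")
    indptr_buf[:batch_size + 1] = indptr
    indices_buf[:len(indices)] = indices
    last_page_len_buf[:batch_size] = last_page_len
    return batch_size


# Buffers sized once for max_batch_size = 4 and max_num_pages = 8.
indptr_buf, indices_buf, last_page_len_buf = [0] * 5, [0] * 8, [0] * 4
n = fill_graph_buffers(indptr_buf, indices_buf, last_page_len_buf,
                       indptr=[0, 2, 3], indices=[5, 1, 7], last_page_len=[3, 1])
print(n, indptr_buf, last_page_len_buf)  # 2 [0, 2, 3, 0, 0] [3, 1, 0, 0]
```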