flashinfer.cascade#

Merge Attention States#

merge_state(v_a, s_a, v_b, s_b)

Merge the attention output V and the logsumexp value S from two KV segments.

merge_state_in_place(v, s, v_other, s_other)

Merge the self-attention state (v, s) with another state (v_other, s_other) in-place.

merge_states(v, s)

Merge multiple attention states (v, s).
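
These primitives implement the core trick behind cascade inference: attention over disjoint KV segments can be computed independently, and the resulting (output, logsumexp) pairs can then be merged into the result of attention over the concatenated segments. Below is a minimal sketch of merge_state; the shapes and dtypes are illustrative assumptions, not requirements stated above.

>>> import torch
>>> import flashinfer
>>> seq_len, num_heads, head_dim = 100, 32, 128
>>> # attention output / logsumexp computed over KV segment A
>>> # (assumed shapes: v: [seq_len, num_heads, head_dim] in fp16, s: [seq_len, num_heads] in fp32)
>>> v_a = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> s_a = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda:0")
>>> # attention output / logsumexp computed over KV segment B
>>> v_b = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> s_b = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda:0")
>>> v_merged, s_merged = flashinfer.merge_state(v_a, s_a, v_b, s_b)
>>> v_merged.shape
torch.Size([100, 32, 128])

merge_state_in_place follows the same pattern but writes the merged result into the first state, and merge_states reduces over a stacked dimension of multiple states.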

Cascade Attention#

Cascade Attention Wrapper Classes#

class flashinfer.cascade.MultiLevelCascadeAttentionWrapper(num_levels, float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD')#

Attention wrapper for memory-efficient multi-level cascade inference. This API assumes the KV-Cache of all levels is stored in a unified paged table.

Please check Multi-level Cascade Inference Data Layout for data layout in cascade inference. Note that it’s not always beneficial to increase the number of levels because of the overhead of merging attention results.

The idea of cascade inference is introduced in our blog post.

Example

>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 8
>>> head_dim = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(
...     2, workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> shared_kv_num_pages = 512
>>> unique_kv_num_pages = 128
>>> total_num_pages = shared_kv_num_pages + unique_kv_num_pages
>>> shared_kv_page_indices = torch.arange(shared_kv_num_pages).int().to("cuda:0")
>>> shared_kv_page_indptr = torch.tensor([0, shared_kv_num_pages], dtype=torch.int32, device="cuda:0")
>>> unique_kv_page_indices = torch.arange(shared_kv_num_pages, total_num_pages).int().to("cuda:0")
>>> unique_kv_page_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> shared_kv_last_page_len = torch.tensor([page_size], dtype=torch.int32, device="cuda:0")
>>> # 1 <= kv_last_page_len <= page_size
>>> unique_kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> kv_cache_at_layer = [
...     torch.randn(
...         total_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> qo_indptr_arr = [
...     torch.tensor([0, batch_size], dtype=torch.int32, device="cuda:0"),  # top-level for shared KV-Cache
...     torch.arange(batch_size + 1, dtype=torch.int32, device="cuda:0")    # bottom-level for unique KV-Cache
... ]
>>> # create auxiliary data structures for batch decode attention
>>> wrapper.plan(
...     qo_indptr_arr,
...     [shared_kv_page_indptr, unique_kv_page_indptr],
...     [shared_kv_page_indices, unique_kv_page_indices],
...     [shared_kv_last_page_len, unique_kv_last_page_len],
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
...     # compute batch decode attention, reuse auxiliary data structures for all layers
...     o = wrapper.run(q, kv_cache_at_layer[i])
...     outputs.append(o)
...
>>> outputs[0].shape
torch.Size([7, 64, 128])
__init__(num_levels, float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD') None#

Constructor of MultiLevelCascadeAttentionWrapper.

Parameters:
  • num_levels (int) – The number of levels in the cascade attention.

  • float_workspace_buffer (torch.Tensor) – The user-reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB; the device of the workspace buffer should be the same as the device of the input tensors.

  • kv_layout (str) – The layout of the input k/v tensors, could be either NHD or HND.

plan(qo_indptr_arr: List[torch.Tensor], paged_kv_indptr_arr: List[torch.Tensor], paged_kv_indices_arr: List[torch.Tensor], paged_kv_last_page_len: List[torch.Tensor], num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, causal: bool = False, pos_encoding_mode: str = 'NONE', allow_fp16_qk_reduction: bool = False, sm_scale: float | None = None, window_left: int = -1, logits_soft_cap: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, q_data_type: str = 'float16')#

Create auxiliary data structures for multi-level cascade attention for multiple forward calls within the same decode step. Please check Multi-level Cascade Inference Data Layout for data layout in cascade inference.

Parameters:
  • qo_indptr_arr (List[torch.Tensor]) – An array of qo indptr tensors, one per level; the array length should be equal to the number of levels. The last element of each tensor should be the total number of queries/outputs. See the sketch after this parameter list for how queries are grouped per level.

  • paged_kv_indptr_arr (List[torch.Tensor]) – An array of paged kv-cache indptr tensors, one per level; the array length should be equal to the number of levels.

  • paged_kv_indices_arr (List[torch.Tensor]) – An array of paged kv-cache indices tensors, one per level; the array length should be equal to the number of levels.

  • paged_kv_last_page_len (List[torch.Tensor]) – An array of paged kv-cache last page length tensors, one per level; the array length should be equal to the number of levels.

  • num_qo_heads (int) – The number of query/output heads.

  • num_kv_heads (int) – The number of key/value heads.

  • head_dim (int) – The dimension of the heads.

  • page_size (int) – The page size of the paged kv-cache.

  • causal (bool) – Whether to apply causal mask to the attention matrix. This is only effective when custom_mask is not provided in plan().

  • pos_encoding_mode (str) – The position encoding applied inside attention kernels, could be NONE/ROPE_LLAMA (LLAMA style rotary embedding) /ALIBI. Default is NONE.

  • allow_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).

  • window_left (int) – The left (inclusive) window size for the attention window, when set to -1, the window size will be set to the full length of the sequence. Defaults to -1.

  • logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to 0. If greater than 0, the logits will be capped according to formula: \(\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})\), where \(x\) is the input logits.

  • sm_scale (Optional[float]) – The scale used in softmax, if not provided, will be set to 1.0 / sqrt(head_dim).

  • rope_scale (Optional[float]) – The scale used in RoPE interpolation, if not provided, will be set to 1.0.

  • rope_theta (Optional[float]) – The theta used in RoPE, if not provided, will be set to 1e4.

  • q_data_type (Optional[Union[str, torch.dtype]]) – The data type of the query tensor. If None, will be set to torch.float16.
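
To make the per-level grouping concrete, the sketch below shows a hypothetical 3-level qo_indptr_arr for a decode batch of 4 requests: all requests share a top-level prefix, requests {0, 1} and {2, 3} each share a mid-level prefix, and every request has its own unique suffix. The level count and grouping here are illustrative assumptions, not part of the API.

>>> import torch
>>> # one query per request (decode), batch_size = 4
>>> qo_indptr_arr = [
...     # level 0: one KV segment shared by all 4 queries
...     torch.tensor([0, 4], dtype=torch.int32, device="cuda:0"),
...     # level 1: two KV segments, each shared by 2 queries
...     torch.tensor([0, 2, 4], dtype=torch.int32, device="cuda:0"),
...     # level 2: one unique KV segment per query
...     torch.tensor([0, 1, 2, 3, 4], dtype=torch.int32, device="cuda:0"),
... ]
>>> # at each level, paged_kv_indptr / paged_kv_indices / paged_kv_last_page_len then
>>> # describe len(qo_indptr) - 1 KV segments for that level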

reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffers: List[torch.Tensor]) None#

Reset the workspace buffer.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The new float workspace buffer; its device should be the same as the device of the input tensors.

  • int_workspace_buffers (List[torch.Tensor]) – The list of new int workspace buffers; their device should be the same as the device of the input tensors.
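
A minimal usage sketch, reusing torch and the wrapper from the example above. The buffer sizes and the assumption of one int workspace buffer per cascade level are illustrative, not documented requirements.

>>> new_float_workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> new_int_workspace_buffers = [
...     torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
...     for _ in range(2)  # assumed: one int workspace buffer per cascade level
... ]
>>> wrapper.reset_workspace_buffer(new_float_workspace_buffer, new_int_workspace_buffers)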

run(q: torch.Tensor, paged_kv_cache: torch.Tensor)#

Compute multi-level cascade attention.

Parameters:
  • q (torch.Tensor) – The query tensor, shape: [batch_size, num_qo_heads, head_dim].

  • paged_kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) –

    The paged KV-Cache stored as a tuple of tensors or a single tensor:

    • a tuple (k_cache, v_cache) of 4-D tensors, each with shape: [max_num_pages, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, num_kv_heads, page_size, head_dim] if kv_layout is HND.

    • a single 5-D tensor with shape: [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND. Where paged_kv_cache[:, 0] is the key-cache and paged_kv_cache[:, 1] is the value-cache.
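
The example above passes the single 5-D tensor form; the equivalent tuple form can be derived from it as sketched below, reusing names from the example above (the .contiguous() calls are a defensive assumption).

>>> kv_cache = kv_cache_at_layer[i]        # [total_num_pages, 2, page_size, num_kv_heads, head_dim] (NHD)
>>> k_cache = kv_cache[:, 0].contiguous()  # [total_num_pages, page_size, num_kv_heads, head_dim]
>>> v_cache = kv_cache[:, 1].contiguous()  # [total_num_pages, page_size, num_kv_heads, head_dim]
>>> o = wrapper.run(q, (k_cache, v_cache)) # same result as passing the 5-D tensor directly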

class flashinfer.cascade.BatchDecodeWithSharedPrefixPagedKVCacheWrapper(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD')#

Wrapper class for decode attention with shared-prefix paged kv-cache for a batch of requests. The shared-prefix KV-Cache is stored in standalone tensors, and the unique KV-Cache of each request is stored in a paged KV-Cache data structure.

Check our tutorial for page table layout.

Warning

This API will be deprecated in the future; please use MultiLevelCascadeAttentionWrapper instead.

Example

>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 8
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> wrapper = flashinfer.BatchDecodeWithSharedPrefixPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> shared_prefix_len = 8192
>>> unique_kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> unique_kv_page_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= kv_last_page_len <= page_size
>>> unique_kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> unique_kv_cache_at_layer = [
...     torch.randn(
...         max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> shared_k_data_at_layer = [
...     torch.randn(
...         shared_prefix_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> shared_v_data_at_layer = [
...     torch.randn(
...         shared_prefix_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> # create auxiliary data structures for batch decode attention
>>> wrapper.begin_forward(
...     unique_kv_page_indptr,
...     unique_kv_page_indices,
...     unique_kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     data_type=torch.float16
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
...     k_shared = shared_k_data_at_layer[i]
...     v_shared = shared_v_data_at_layer[i]
...     unique_kv_cache = unique_kv_cache_at_layer[i]
...     # compute batch decode attention, reuse auxiliary data structures for all layers
...     o = wrapper.forward(q, k_shared, v_shared, unique_kv_cache)
...     outputs.append(o)
...
>>> outputs[0].shape
torch.Size([7, 64, 128])

Note

To accelerate computation, FlashInfer’s shared-prefix batch decode attention creates some auxiliary data structures that can be reused across multiple batch decode attention calls (e.g., in different Transformer layers). This wrapper class manages the lifecycle of these data structures.

__init__(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD') None#
begin_forward(unique_kv_indptr: torch.Tensor, unique_kv_indices: torch.Tensor, unique_kv_last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, data_type: str = 'float16') None#

Plan shared-prefix batch decode attention for given problem specification.

Parameters:
  • unique_kv_indptr (torch.Tensor) – The indptr of the paged kv cache, shape: [batch_size + 1]

  • unique_kv_indices (torch.Tensor) – The page indices of the paged kv cache, shape: [unique_kv_indptr[-1]]

  • unique_kv_last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv cache, shape: [batch_size]

  • num_qo_heads (int) – The number of query/output heads

  • num_kv_heads (int) – The number of key/value heads

  • head_dim (int) – The dimension of the heads

  • page_size (int) – The page size of the paged kv cache

  • data_type (Union[str, torch.dtype]) – The data type of the paged kv cache

Note

The begin_forward() method should be called before any forward() or forward_return_lse() calls; auxiliary data structures will be created during this call and cached for multiple forward calls.

The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.

end_forward() None#

Warning: this function is deprecated and has no effect

forward(q: torch.Tensor, k_shared: torch.Tensor, v_shared: torch.Tensor, unique_kv_cache: torch.Tensor) torch.Tensor#

Compute batch decode attention between queries and shared-prefix paged kv-cache.

Parameters:
  • q (torch.Tensor) – The query tensor, shape: [batch_size, num_qo_heads, head_dim].

  • k_shared (torch.Tensor) – The shared prefix key tensor, shape: [shared_prefix_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, shared_prefix_len, head_dim] if kv_layout is HND.

  • v_shared (torch.Tensor) – The shared prefix value tensor, shape: [shared_prefix_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, shared_prefix_len, head_dim] if kv_layout is HND.

  • unique_kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) –

    The per-request (non-shared) suffix paged KV-Cache stored as a tuple of tensors or a single tensor:

    • a tuple (k_cache, v_cache) of 4-D tensors, each with shape: [max_num_pages, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, num_kv_heads, page_size, head_dim] if kv_layout is HND.

    • a single 5-D tensor with shape: [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND. Where paged_kv_cache[:, 0] is the key-cache and paged_kv_cache[:, 1] is the value-cache.

Returns:

V – The attention output, shape: [batch_size, num_qo_heads, head_dim]

Return type:

torch.Tensor

reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor) None#

Reset the workspace buffer.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The new float workspace buffer; its device should be the same as the device of the input tensors.

  • int_workspace_buffer (torch.Tensor) – The new int workspace buffer; its device should be the same as the device of the input tensors.

class flashinfer.cascade.BatchPrefillWithSharedPrefixPagedKVCacheWrapper(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD')#

Wrapper class for prefill/append attention with shared-prefix paged kv-cache for a batch of requests.

Check our tutorial for paged kv-cache layout.

Warning

This API will be deprecated in the future; please use MultiLevelCascadeAttentionWrapper instead.

Example

>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 16
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> prefill_wrapper = flashinfer.BatchPrefillWithSharedPrefixPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> shared_prefix_len = 8192
>>> nnz_qo = 100
>>> qo_indptr = torch.tensor(
...     [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
... )
>>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> paged_kv_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= paged_kv_last_page_len <= page_size
>>> paged_kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> kv_cache_at_layer = [
...     torch.randn(
...         max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> shared_k_data_at_layer = [
...     torch.randn(
...         shared_prefix_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> shared_v_data_at_layer = [
...     torch.randn(
...         shared_prefix_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> # create auxiliary data structures for batch prefill attention
>>> prefill_wrapper.begin_forward(
...     qo_indptr,
...     paged_kv_indptr,
...     paged_kv_indices,
...     paged_kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = torch.randn(nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
...     kv_cache = kv_cache_at_layer[i]
...     k_shared = shared_k_data_at_layer[i]
...     v_shared = shared_v_data_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o = prefill_wrapper.forward(
...         q, k_shared, v_shared, kv_cache, causal=True
...     )
...     outputs.append(o)
...
>>> # clear auxiliary data structures
>>> prefill_wrapper.end_forward()
>>> outputs[0].shape
torch.Size([100, 64, 128])

Note

To accelerate computation, FlashInfer’s shared-prefix batch prefill/append attention operators create some auxiliary data structures that can be reused across multiple prefill/append attention calls (e.g., in different Transformer layers). This wrapper class manages the lifecycle of these data structures.

__init__(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD') None#

Constructor of BatchPrefillWithSharedPrefixPagedKVCacheWrapper.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The user-reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB; the device of the workspace buffer should be the same as the device of the input tensors.

  • kv_layout (str) – The layout of the input k/v tensors, could be either NHD or HND.

begin_forward(qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor, paged_kv_indices: torch.Tensor, paged_kv_last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int) None#

Create auxiliary data structures for shared-prefix batch prefill/append attention for multiple forward calls within the same prefill/append step.

Parameters:
  • qo_indptr (torch.Tensor) – The indptr of the query/output tensor, shape: [batch_size + 1].

  • paged_kv_indptr (torch.Tensor) – The indptr of the paged kv-cache, shape: [batch_size + 1].

  • paged_kv_indices (torch.Tensor) – The page indices of the paged kv-cache, shape: [paged_kv_indptr[-1]].

  • paged_kv_last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv-cache, shape: [batch_size].

  • num_qo_heads (int) – The number of query/output heads.

  • num_kv_heads (int) – The number of key/value heads.

  • head_dim (int) – The dimension of the heads.

  • page_size (int) – The page size of the paged kv-cache.

Note

The begin_forward() method should be called before any forward() or forward_return_lse() calls; auxiliary data structures will be created during this call and cached for multiple forward calls.

The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.

end_forward() None#

Warning: this function is deprecated and has no effect

forward(q: torch.Tensor, k_shared: torch.Tensor, v_shared: torch.Tensor, unique_kv_cache: torch.Tensor, causal: bool = True, allow_fp16_qk_reduction: bool = False, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None) torch.Tensor#

Compute batch prefill/append attention between query and shared-prefix paged kv-cache.

Parameters:
  • q (torch.Tensor) – The query tensor, shape: [qo_indptr[-1], num_qo_heads, head_dim].

  • k_shared (torch.Tensor) – The shared prefix key tensor, shape: [shared_prefix_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, shared_prefix_len, head_dim] if kv_layout is HND.

  • v_shared (torch.Tensor) – The shared prefix value tensor, shape: [shared_prefix_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, shared_prefix_len, head_dim] if kv_layout is HND.

  • unique_kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) –

    The per-request (non-shared) suffix paged KV-Cache stored as a tuple of tensors or a single tensor:

    • a tuple (k_cache, v_cache) of 4-D tensors, each with shape: [max_num_pages, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, num_kv_heads, page_size, head_dim] if kv_layout is HND.

    • a single 5-D tensor with shape: [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND. Where paged_kv_cache[:, 0] is the key-cache and paged_kv_cache[:, 1] is the value-cache.

  • causal (bool) – Whether to apply causal mask on the attention matrix.

  • allow_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).

  • sm_scale (Optional[float]) – The scale of softmax, if not provided, will be set to 1 / sqrt(head_dim).

  • rope_scale (Optional[float]) – The scale used in RoPE interpolation, if not provided, will be set to 1.0.

  • rope_theta (Optional[float]) – The theta used in RoPE, if not provided, will be set to 1e4.

Returns:

V – The attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim].

Return type:

torch.Tensor

reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor) None#

Reset the workspace buffer.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The new float workspace buffer; its device should be the same as the device of the input tensors.

  • int_workspace_buffer (torch.Tensor) – The new int workspace buffer; its device should be the same as the device of the input tensors.