flashinfer.cascade#
Merge Attention States#
- Merge the attention output.
- Merge the self-attention state.
- Merge multiple attention states (v, s).
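The idea behind merging attention states can be sketched in plain PyTorch. This is a minimal single-head illustration, not the library's implementation: it assumes each state is a pair (v, s) where s is the natural-log log-sum-exp of the attention scores (the library may use a different log base or tensor layout), and the names attention_with_lse and merge_state_reference are hypothetical.

```python
import torch


def attention_with_lse(q, k, v):
    """Single-head attention over one KV segment.

    q: [d], k: [n, d], v: [n, d]. Returns the output [d] and the
    log-sum-exp of the scaled scores (a scalar tensor).
    """
    scores = (k @ q) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=0)
    out = torch.softmax(scores, dim=0) @ v
    return out, lse


def merge_state_reference(v_a, s_a, v_b, s_b):
    """Combine two partial attention states into the state of their union."""
    s = torch.logaddexp(s_a, s_b)
    # Each partial output is re-weighted by its share of the total softmax mass.
    v = v_a * torch.exp(s_a - s) + v_b * torch.exp(s_b - s)
    return v, s


torch.manual_seed(0)
d = 8
q = torch.randn(d, dtype=torch.float64)
k = torch.randn(10, d, dtype=torch.float64)
v = torch.randn(10, d, dtype=torch.float64)

# Attend to two disjoint KV segments separately, then merge.
v_a, s_a = attention_with_lse(q, k[:6], v[:6])
v_b, s_b = attention_with_lse(q, k[6:], v[6:])
v_merged, s_merged = merge_state_reference(v_a, s_a, v_b, s_b)

# Attention over the concatenated KV gives the same result.
v_full, s_full = attention_with_lse(q, k, v)
```

Because the merge is associative, the same rule can combine a shared-prefix state with a per-request suffix state, which is the pattern the cascade wrappers on this page rely on.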
Cascade Attention#
Decode attention between queries and a shared-prefix kv-cache for a batch of requests.
Cascade Attention Wrapper Classes#
Wrapper class for decode attention with a shared-prefix paged kv-cache for a batch of requests.
Check our tutorial for page table layout.
Example
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 8
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 16MB workspace buffer
>>> workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> wrapper = flashinfer.BatchDecodeWithSharedPrefixPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> shared_prefix_len = 8192
>>> unique_kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> unique_kv_page_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= kv_last_page_len <= page_size
>>> unique_kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> unique_kv_data_at_layer = [
...     torch.randn(
...         max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> shared_k_data_at_layer = [
...     torch.randn(
...         shared_prefix_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> shared_v_data_at_layer = [
...     torch.randn(
...         shared_prefix_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> # create auxiliary data structures for batch decode attention
>>> wrapper.begin_forward(
...     unique_kv_page_indptr,
...     unique_kv_page_indices,
...     unique_kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     data_type=torch.float16
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
...     k_shared = shared_k_data_at_layer[i]
...     v_shared = shared_v_data_at_layer[i]
...     unique_kv_data = unique_kv_data_at_layer[i]
...     # compute batch decode attention, reuse auxiliary data structures for all layers
...     o = wrapper.forward(q, k_shared, v_shared, unique_kv_data)
...     outputs.append(o)
...
>>> # clear auxiliary data structures
>>> wrapper.end_forward()
>>> outputs[0].shape
torch.Size([7, 64, 128])
Note
To accelerate computation, FlashInfer’s shared-prefix batch decode attention creates auxiliary data structures that can be reused across multiple batch decode attention calls (e.g. different Transformer layers). This wrapper class manages the lifecycle of these data structures.
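The lifecycle described in this note (set up once, run many forward calls, then clean up) can be made harder to misuse with a small context manager. The sketch below is an illustration only: planned is a hypothetical helper that is not part of FlashInfer, and FakeWrapper is a stand-in used to show the call order without a GPU. It assumes only that the wrapped object exposes begin_forward() and end_forward().

```python
from contextlib import contextmanager


@contextmanager
def planned(wrapper, *args, **kwargs):
    """Call begin_forward on entry and end_forward on exit, even on error."""
    wrapper.begin_forward(*args, **kwargs)
    try:
        yield wrapper
    finally:
        wrapper.end_forward()


class FakeWrapper:
    """Stand-in that records calls; not a FlashInfer class."""

    def __init__(self):
        self.calls = []

    def begin_forward(self, *args, **kwargs):
        self.calls.append("begin_forward")

    def forward(self, q):
        self.calls.append("forward")
        return q

    def end_forward(self):
        self.calls.append("end_forward")


fake = FakeWrapper()
with planned(fake) as w:
    # One begin_forward serves every layer in the step.
    for layer in range(3):
        w.forward(layer)
```

With a real wrapper, the arguments passed to planned would be the ones begin_forward() expects.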
Create auxiliary data structures for shared-prefix batch decode for multiple forward calls within the same decode step.
- Parameters:
indptr (torch.Tensor) – The indptr of the paged kv cache, shape: [batch_size + 1]
indices (torch.Tensor) – The page indices of the paged kv cache, shape: [indptr[-1]]
last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv cache, shape: [batch_size]
num_qo_heads (int) – The number of query/output heads
num_kv_heads (int) – The number of key/value heads
head_dim (int) – The dimension of the heads
page_size (int) – The page size of the paged kv cache
data_type (Union[str, torch.dtype]) – The data type of the paged kv cache
Note
The begin_forward() method should be called before any forward() or forward_return_lse() calls. Auxiliary data structures are created during this call and cached for multiple forward calls.
The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.
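How the indptr, last-page-length, and page-size arguments relate can be sketched with the values from the example above. This is a minimal illustration under stated assumptions: every request owns at least one page (consistent with 1 <= last_page_len <= page_size), all pages except the last are full, and kv_lengths is a hypothetical helper rather than a FlashInfer function. The lengths cover only the unique, non-shared part of each request.

```python
import torch


def kv_lengths(indptr, last_page_len, page_size):
    """Number of unique kv tokens per request, derived from the page table."""
    num_pages = indptr[1:] - indptr[:-1]
    # All pages but the last are assumed full; the last holds last_page_len entries.
    return (num_pages - 1) * page_size + last_page_len


indptr = torch.tensor([0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32)
last_page_len = torch.tensor([1, 7, 14, 4, 3, 1, 16], dtype=torch.int32)
lengths = kv_lengths(indptr, last_page_len, page_size=16)

# Grouped query attention requires the head counts to divide evenly.
num_qo_heads, num_kv_heads = 64, 8
assert num_qo_heads % num_kv_heads == 0
group_size = num_qo_heads // num_kv_heads  # query heads per kv head
```

Here the first request owns 17 pages with one entry in the last page, so it holds 16 * 16 + 1 = 257 unique tokens.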
Clear auxiliary data structures created by begin_forward().
Compute batch decode attention between queries and shared-prefix paged kv-cache.
- Parameters:
q (torch.Tensor) – The query tensor, shape: [batch_size, num_qo_heads, head_dim].
k_shared (torch.Tensor) – The shared prefix key tensor, shape: [shared_prefix_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, shared_prefix_len, head_dim] if kv_layout is HND.
v_shared (torch.Tensor) – The shared prefix value tensor, shape: [shared_prefix_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, shared_prefix_len, head_dim] if kv_layout is HND.
unique_kv_data (torch.Tensor) – A 5-D tensor of paged kv-cache data storing the unique (per-request) suffix key and value tensors, shape: [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, or [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND.
allow_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).
sm_scale (Optional[float]) – The scale of softmax. If not provided, it is set to 1 / sqrt(head_dim).
rope_scale (Optional[float]) – The scale used in RoPE interpolation. If not provided, it is set to 1.0.
rope_theta (Optional[float]) – The theta used in RoPE. If not provided, it is set to 1e4.
- Returns:
V – The attention output, shape: [batch_size, num_qo_heads, head_dim]
- Return type:
torch.Tensor
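To make the argument shapes concrete, the following is a naive reference for what a shared-prefix decode call computes in the NHD layout. It is a sketch under stated assumptions, not the library's kernel: it ignores RoPE and the fp16 reduction option, uses the default softmax scale 1 / sqrt(head_dim), assumes query head h reads kv head h // group_size, and takes the page table as explicit arguments. The function name reference_shared_prefix_decode is hypothetical.

```python
import math

import torch


def reference_shared_prefix_decode(q, k_shared, v_shared, unique_kv_data,
                                   indptr, indices, last_page_len):
    """Naive NHD reference: each query attends to shared prefix + its own suffix."""
    batch_size, num_qo_heads, head_dim = q.shape
    page_size, num_kv_heads = unique_kv_data.shape[2], unique_kv_data.shape[3]
    group_size = num_qo_heads // num_kv_heads
    sm_scale = 1.0 / math.sqrt(head_dim)
    outputs = []
    for i in range(batch_size):
        start, end = int(indptr[i]), int(indptr[i + 1])
        pages = indices[start:end].long()
        kv_len = (end - start - 1) * page_size + int(last_page_len[i])
        # [n_pages, page_size, H_kv, D] -> [n_pages * page_size, H_kv, D], then trim.
        k_unique = unique_kv_data[pages, 0].reshape(-1, num_kv_heads, head_dim)[:kv_len]
        v_unique = unique_kv_data[pages, 1].reshape(-1, num_kv_heads, head_dim)[:kv_len]
        k = torch.cat([k_shared, k_unique], dim=0)
        v = torch.cat([v_shared, v_unique], dim=0)
        # Expand kv heads so that each query head has a matching kv head (GQA).
        k = k.repeat_interleave(group_size, dim=1)
        v = v.repeat_interleave(group_size, dim=1)
        scores = torch.einsum("hd,nhd->hn", q[i], k) * sm_scale
        probs = torch.softmax(scores, dim=-1)
        outputs.append(torch.einsum("hn,nhd->hd", probs, v))
    return torch.stack(outputs)


# Tiny CPU example: 2 requests, 3 pages of 4 entries, shared prefix of 5 tokens.
torch.manual_seed(0)
num_qo_heads, num_kv_heads, head_dim, page_size = 4, 2, 8, 4
indptr = torch.tensor([0, 2, 3], dtype=torch.int32)
indices = torch.tensor([2, 0, 1], dtype=torch.int32)
last_page_len = torch.tensor([3, 4], dtype=torch.int32)
q = torch.randn(2, num_qo_heads, head_dim)
k_shared = torch.randn(5, num_kv_heads, head_dim)
v_shared = torch.randn(5, num_kv_heads, head_dim)
unique_kv_data = torch.randn(3, 2, page_size, num_kv_heads, head_dim)
out = reference_shared_prefix_decode(
    q, k_shared, v_shared, unique_kv_data, indptr, indices, last_page_len
)
```

A reference like this is mainly useful for checking shapes and for comparing against the wrapper's output on small inputs.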
Reset the workspace buffer.
- Parameters:
new_workspace_buffer (torch.Tensor) – The new workspace buffer. It should be on the same device as the input tensors.
Wrapper class for prefill/append attention with a shared-prefix paged kv-cache for a batch of requests.
Check our tutorial for paged kv-cache layout.
Example
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 16
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 16MB workspace buffer
>>> workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> prefill_wrapper = flashinfer.BatchPrefillWithSharedPrefixPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> shared_prefix_len = 8192
>>> nnz_qo = 100
>>> qo_indptr = torch.tensor(
...     [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
... )
>>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> paged_kv_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= paged_kv_last_page_len <= page_size
>>> paged_kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> kv_data_at_layer = [
...     torch.randn(
...         max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> shared_k_data_at_layer = [
...     torch.randn(
...         shared_prefix_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> shared_v_data_at_layer = [
...     torch.randn(
...         shared_prefix_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> # create auxiliary data structures for batch prefill attention
>>> prefill_wrapper.begin_forward(
...     qo_indptr,
...     paged_kv_indptr,
...     paged_kv_indices,
...     paged_kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = torch.randn(nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
...     kv_data = kv_data_at_layer[i]
...     k_shared = shared_k_data_at_layer[i]
...     v_shared = shared_v_data_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o = prefill_wrapper.forward(
...         q, k_shared, v_shared, kv_data, causal=True
...     )
...     outputs.append(o)
...
>>> # clear auxiliary data structures
>>> prefill_wrapper.end_forward()
>>> outputs[0].shape
torch.Size([100, 64, 128])
Note
To accelerate computation, FlashInfer’s shared-prefix batch prefill/append attention operators create auxiliary data structures that can be reused across multiple prefill/append attention calls (e.g. different Transformer layers). This wrapper class manages the lifecycle of these data structures.
Create auxiliary data structures for shared-prefix batch prefill/append attention for multiple forward calls within the same prefill/append step.
- Parameters:
qo_indptr (torch.Tensor) – The indptr of the query/output tensor, shape: [batch_size + 1].
paged_kv_indptr (torch.Tensor) – The indptr of the paged kv-cache, shape: [batch_size + 1].
paged_kv_indices (torch.Tensor) – The page indices of the paged kv-cache, shape: [paged_kv_indptr[-1]].
paged_kv_last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv-cache, shape: [batch_size].
num_qo_heads (int) – The number of query/output heads.
num_kv_heads (int) – The number of key/value heads.
head_dim (int) – The dimension of the heads.
Notes
The begin_forward() method should be called before any forward() or forward_return_lse() calls. Auxiliary data structures are created during this call and cached for multiple forward calls.
The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.
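The role of qo_indptr is easiest to see by slicing a ragged query tensor back into per-request pieces. The sketch below reuses the qo_indptr values from the example above with small, purely illustrative head sizes; it runs on CPU and does not call FlashInfer.

```python
import torch

qo_indptr = torch.tensor([0, 33, 44, 55, 66, 77, 88, 100], dtype=torch.int32)
num_qo_heads, head_dim = 4, 8  # small illustrative sizes, not the example's 64 and 128

# All requests' query tokens are packed along the first dimension.
q = torch.randn(int(qo_indptr[-1]), num_qo_heads, head_dim)

# Request i owns rows qo_indptr[i] : qo_indptr[i + 1].
qo_lens = (qo_indptr[1:] - qo_indptr[:-1]).tolist()
per_request_q = [
    q[int(qo_indptr[i]):int(qo_indptr[i + 1])] for i in range(len(qo_lens))
]
```

The first request contributes 33 query tokens, the last one 12, and the packed tensor has qo_indptr[-1] = 100 rows in total.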
Clear the auxiliary data structures created by begin_forward().
Compute batch prefill/append attention between query and shared-prefix paged kv-cache.
- Parameters:
q (torch.Tensor) – The query tensor, shape: [qo_indptr[-1], num_qo_heads, head_dim].
k_shared (torch.Tensor) – The shared prefix key tensor, shape: [shared_prefix_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, shared_prefix_len, head_dim] if kv_layout is HND.
v_shared (torch.Tensor) – The shared prefix value tensor, shape: [shared_prefix_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, shared_prefix_len, head_dim] if kv_layout is HND.
unique_kv_data (torch.Tensor) – A 5-D tensor of paged kv-cache data storing the unique (per-request) suffix key and value tensors, shape: [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, or [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND.
causal (bool) – Whether to apply causal mask on the attention matrix.
allow_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).
sm_scale (Optional[float]) – The scale of softmax. If not provided, it is set to 1 / sqrt(head_dim).
rope_scale (Optional[float]) – The scale used in RoPE interpolation. If not provided, it is set to 1.0.
rope_theta (Optional[float]) – The theta used in RoPE. If not provided, it is set to 1e4.
- Returns:
V – The attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim].
- Return type:
torch.Tensor
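When causal is enabled for append attention, the query tokens are typically the most recent tokens of the sequence, so the mask is aligned to the end of the kv-cache rather than the start. The sketch below shows that convention for a single request. It is an assumption about the masking convention, not a statement of the library's internals, and append_causal_mask is a hypothetical helper. Under this convention any shared prefix, which precedes all query tokens, is fully visible.

```python
import torch


def append_causal_mask(qo_len, kv_len):
    """Boolean mask [qo_len, kv_len]; True where query row i may attend to key j.

    Assumes the qo_len query tokens are the last qo_len positions of the
    kv sequence, so query i sits at absolute position kv_len - qo_len + i.
    """
    q_pos = torch.arange(qo_len).unsqueeze(1) + (kv_len - qo_len)
    k_pos = torch.arange(kv_len).unsqueeze(0)
    return k_pos <= q_pos


# 3 new query tokens appended to a sequence that now has 5 kv entries.
mask = append_causal_mask(qo_len=3, kv_len=5)
```

The first new token sees the two earlier entries and itself, and the last new token sees all five entries.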
Reset the workspace buffer.
- Parameters:
new_workspace_buffer (torch.Tensor) – The new workspace buffer. It should be on the same device as the input tensors.