flashinfer.prefill
Attention kernels for prefill and append attention, in both single-request and batch serving settings.
Single Request Prefill/Append Attention
Prefill/Append attention with KV cache for a single request; returns the attention output.
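For intuition about what a single-request prefill/append kernel computes, here is a minimal pure-Python reference sketch for one attention head. It assumes, consistently with the custom-mask examples on this page, that the qo_len query tokens are the last qo_len positions of the kv_len-token sequence, so the causal mask is offset by kv_len - qo_len. The function name reference_attention is hypothetical and not part of FlashInfer; it ignores positional encodings, soft capping and windowing.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def reference_attention(q: Matrix, k: Matrix, v: Matrix,
                        causal: bool = True,
                        sm_scale: Optional[float] = None) -> Matrix:
    """Hypothetical reference for one head: q is [qo_len][d], k/v are [kv_len][d]."""
    qo_len, kv_len, head_dim = len(q), len(k), len(q[0])
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)  # documented default of sm_scale
    offset = kv_len - qo_len  # query i sits at absolute position i + offset (assumption)
    out = []
    for i, qi in enumerate(q):
        # Keys this query may attend to: all of them, or a causal prefix.
        visible = range(kv_len if not causal else i + offset + 1)
        logits = [sm_scale * sum(a * b for a, b in zip(qi, k[j])) for j in visible]
        m = max(logits)  # subtract the max for a numerically stable softmax
        weights = [math.exp(x - m) for x in logits]
        total = sum(weights)
        out.append([
            sum(w * v[j][c] for w, j in zip(weights, visible)) / total
            for c in range(len(v[0]))
        ])
    return out


q = [[0.0, 0.0]]                          # one appended query token
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three cached keys
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(reference_attention(q, k, v))       # [[3.0, 4.0]]: all logits are 0, so weights are uniform
```

In the append case above the single query is the last token, so it sees every key; in a pure prefill (qo_len == kv_len) the first query sees only the first key.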
Batch Prefill/Append Attention
- class flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: torch.Tensor | None = None, paged_kv_indptr_buf: torch.Tensor | None = None, paged_kv_indices_buf: torch.Tensor | None = None, paged_kv_last_page_len_buf: torch.Tensor | None = None, custom_mask_buf: torch.Tensor | None = None, qk_indptr_buf: torch.Tensor | None = None)
Wrapper class for prefill/append attention with a paged kv-cache for a batch of requests.
See the tutorial for the page table layout.
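As a rough illustration of the page table, the sketch below resolves every KV token of one request to a (page id, slot within page) pair from paged_kv_indptr, paged_kv_indices and paged_kv_last_page_len. It uses plain Python lists, the helper name token_locations is hypothetical, and it assumes every page except a request's last one is full, which is what the 1 <= paged_kv_last_page_len <= page_size convention in the example below implies.

```python
from typing import List, Tuple


def token_locations(req: int, paged_kv_indptr: List[int], paged_kv_indices: List[int],
                    paged_kv_last_page_len: List[int], page_size: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: (page_id, slot) for each KV token of request `req`."""
    # This request's pages are a contiguous slice of paged_kv_indices.
    pages = paged_kv_indices[paged_kv_indptr[req]:paged_kv_indptr[req + 1]]
    locations = []
    for n, page_id in enumerate(pages):
        # All pages are full except the last, which holds last_page_len entries.
        used = page_size if n < len(pages) - 1 else paged_kv_last_page_len[req]
        locations.extend((page_id, slot) for slot in range(used))
    return locations


# Two requests: request 0 owns pages [5, 1], request 1 owns page [7].
indptr, indices, last_len, page_size = [0, 2, 3], [5, 1, 7], [3, 4], 4
print(token_locations(0, indptr, indices, last_len, page_size))
# [(5, 0), (5, 1), (5, 2), (5, 3), (1, 0), (1, 1), (1, 2)]
```

Pages need not be contiguous or ordered in memory; only their order within paged_kv_indices determines token order.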
Example
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 16
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> nnz_qo = 100
>>> qo_indptr = torch.tensor(
...     [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
... )
>>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> paged_kv_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= paged_kv_last_page_len <= page_size
>>> paged_kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
>>> kv_cache_at_layer = torch.randn(
...     num_layers, max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
... )
>>> # create auxiliary data structures for batch prefill attention
>>> prefill_wrapper.plan(
...     qo_indptr,
...     paged_kv_indptr,
...     paged_kv_indices,
...     paged_kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     causal=True,
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     kv_cache = kv_cache_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o = prefill_wrapper.run(q, kv_cache)
...     outputs.append(o)
...
>>> outputs[0].shape
torch.Size([100, 64, 128])
>>>
>>> # below is another example of creating custom mask for batch prefill attention
>>> mask_arr = []
>>> qo_len = (qo_indptr[1:] - qo_indptr[:-1]).cpu().tolist()
>>> kv_len = (page_size * (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) + paged_kv_last_page_len).cpu().tolist()
>>> for i in range(batch_size):
...     mask_i = torch.tril(
...         torch.full((qo_len[i], kv_len[i]), True, device="cuda:0"),
...         diagonal=(kv_len[i] - qo_len[i]),
...     )
...     mask_arr.append(mask_i.flatten())
...
>>> mask = torch.cat(mask_arr, dim=0)
>>> prefill_wrapper.plan(
...     qo_indptr,
...     paged_kv_indptr,
...     paged_kv_indices,
...     paged_kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     custom_mask=mask,
... )
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     kv_cache = kv_cache_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o_custom = prefill_wrapper.run(q, kv_cache)
...     assert torch.allclose(o_custom, outputs[i], rtol=1e-3, atol=1e-3)
...
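The example above derives each request's KV length as page_size * (num_pages - 1) + last_page_len. The plain-Python sketch below repeats that arithmetic with the example's numbers (no FlashInfer or GPU needed) and also computes the total length of the flattened custom mask, sum(qo_len[i] * kv_len[i]), which is the shape documented for custom_mask in plan().

```python
page_size = 16
qo_indptr = [0, 33, 44, 55, 66, 77, 88, 100]
paged_kv_indptr = [0, 17, 29, 44, 48, 66, 100, 128]
paged_kv_last_page_len = [1, 7, 14, 4, 3, 1, 16]

# Per-request lengths are differences of consecutive indptr entries.
qo_len = [b - a for a, b in zip(qo_indptr, qo_indptr[1:])]
num_pages = [b - a for a, b in zip(paged_kv_indptr, paged_kv_indptr[1:])]
# Every page but the last is full; the last holds last_page_len entries.
kv_len = [page_size * (n - 1) + last for n, last in zip(num_pages, paged_kv_last_page_len)]
# Number of elements in the flattened boolean custom_mask passed to plan().
mask_numel = sum(q * kv for q, kv in zip(qo_len, kv_len))

print(kv_len)      # [257, 183, 238, 52, 275, 529, 448]
print(mask_numel)  # 27904
```

Every kv_len is at least the matching qo_len, which is what allows the tril(..., diagonal=kv_len - qo_len) construction in the example to reproduce the causal mask.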
Note
To accelerate computation, FlashInfer’s batch prefill/append attention operators create auxiliary data structures that can be reused across multiple prefill/append attention calls (e.g. different Transformer layers). This wrapper class manages the lifecycle of these data structures.
- __init__(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: torch.Tensor | None = None, paged_kv_indptr_buf: torch.Tensor | None = None, paged_kv_indices_buf: torch.Tensor | None = None, paged_kv_last_page_len_buf: torch.Tensor | None = None, custom_mask_buf: torch.Tensor | None = None, qk_indptr_buf: torch.Tensor | None = None) → None
Constructor of BatchPrefillWithPagedKVCacheWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user-reserved workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB, and the workspace buffer should be on the same device as the input tensors.
kv_layout (str) – The layout of the input k/v tensors, either NHD or HND.
use_cuda_graph (bool) – Whether to enable CUDA graph capture for the prefill kernels. If enabled, the auxiliary data structures are stored in the provided buffers, and the batch_size cannot change during the lifecycle of this wrapper.
qo_indptr_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the qo_indptr array; its size should be [batch_size + 1]. Only effective when use_cuda_graph is True.
paged_kv_indptr_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the paged_kv_indptr array; its size should be [batch_size + 1]. Only effective when use_cuda_graph is True.
paged_kv_indices_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the paged_kv_indices array; it should be large enough to hold the largest paged_kv_indices array used during the lifetime of the wrapper. Only effective when use_cuda_graph is True.
paged_kv_last_page_len_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the paged_kv_last_page_len array; its size should be [batch_size]. Only effective when use_cuda_graph is True.
custom_mask_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the custom mask tensor; it should be large enough to hold the largest packed custom mask tensor used during the lifetime of the wrapper. Only effective when use_cuda_graph is True and a custom mask is used in the attention computation.
qk_indptr_buf (Optional[torch.Tensor]) – The user-reserved buffer to store the qk_indptr array; its size should be [batch_size + 1]. Only effective when use_cuda_graph is True and a custom mask is used in the attention computation.
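With use_cuda_graph=True these buffers are allocated once, up front. The sketch below only does the size bookkeeping implied by the parameter descriptions above, for a fixed batch_size and upper bounds you choose yourself; cuda_graph_buffer_sizes and its max_num_pages / max_mask_elems inputs are hypothetical. The custom-mask size assumes one bit per mask element packed into uint8 bytes, and the extra batch_size bytes are a conservative allowance in case each request's packed segment is padded to a whole byte (also an assumption).

```python
def cuda_graph_buffer_sizes(batch_size: int, max_num_pages: int, max_mask_elems: int) -> dict:
    """Hypothetical helper: element counts for the buffers of a CUDA-graph wrapper."""
    return {
        "qo_indptr_buf": batch_size + 1,           # indptr: one entry per request, plus one
        "paged_kv_indptr_buf": batch_size + 1,
        "paged_kv_indices_buf": max_num_pages,     # largest paged_kv_indptr[-1] ever planned
        "paged_kv_last_page_len_buf": batch_size,
        # Packed mask: 1 bit per element in uint8, rounded up, plus per-request
        # byte padding allowance (assumption).
        "custom_mask_buf": (max_mask_elems + 7) // 8 + batch_size,
        "qk_indptr_buf": batch_size + 1,
    }


# Bounds taken from the class example: 7 requests, 128 pages, 27904 mask elements.
sizes = cuda_graph_buffer_sizes(batch_size=7, max_num_pages=128, max_mask_elems=27904)
print(sizes["custom_mask_buf"])  # 3495
```

Each count could then back an allocation such as torch.empty(n, dtype=torch.int32, device="cuda:0"); int32 mirrors the example's indptr tensors and uint8 mirrors the packed mask, but verify the dtypes your FlashInfer version expects.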
- plan(qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor, paged_kv_indices: torch.Tensor, paged_kv_last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, custom_mask: torch.Tensor | None = None, packed_custom_mask: torch.Tensor | None = None, causal: bool = False, pos_encoding_mode: str = 'NONE', allow_fp16_qk_reduction: bool = False, sm_scale: float | None = None, window_left: int = -1, logits_soft_cap: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, q_data_type: str | torch.dtype = 'float16', kv_data_type: str | torch.dtype | None = None, non_blocking: bool = False) → None
Plan batch prefill/append attention on the paged KV-Cache for the given problem specification.
- Parameters:
qo_indptr (torch.Tensor) – The indptr of the query/output tensor, shape: [batch_size + 1].
paged_kv_indptr (torch.Tensor) – The indptr of the paged kv-cache, shape: [batch_size + 1].
paged_kv_indices (torch.Tensor) – The page indices of the paged kv-cache, shape: [paged_kv_indptr[-1]].
paged_kv_last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv-cache, shape: [batch_size].
num_qo_heads (int) – The number of query/output heads.
num_kv_heads (int) – The number of key/value heads.
head_dim (int) – The dimension of the heads.
page_size (int) – The size of each page in the paged kv-cache.
custom_mask (Optional[torch.Tensor]) – The flattened boolean mask tensor, shape: [sum(q_len[i] * k_len[i] for i in range(batch_size))]. Each element must be either True or False, where False means the corresponding element in the attention matrix is masked out. Please refer to the mask layout for more details about the flattened layout of the mask tensor. When custom_mask is provided and packed_custom_mask is not, the function packs the custom mask tensor into a 1D packed mask tensor, which introduces additional overhead.
packed_custom_mask (Optional[torch.Tensor]) – The 1D packed uint8 mask tensor. If provided, custom_mask is ignored. The packed mask tensor is generated by flashinfer.quantization.packbits().
causal (bool) – Whether to apply a causal mask to the attention matrix. Only effective when custom_mask is not provided in plan().
pos_encoding_mode (str) – The position encoding applied inside attention kernels, one of NONE, ROPE_LLAMA (LLaMA-style rotary embedding), or ALIBI. Default is NONE.
allow_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).
window_left (int) – The left (inclusive) window size for the attention window. When set to -1, the window covers the full length of the sequence. Defaults to -1.
logits_soft_cap (Optional[float]) – The attention logits soft-capping value (used in Gemini, Grok, Gemma-2, etc.). If not provided, it is set to 0. If greater than 0, the logits are capped according to the formula \(\texttt{logits\_soft\_cap} \times \tanh(x / \texttt{logits\_soft\_cap})\), where \(x\) is the input logits.
sm_scale (Optional[float]) – The scale used in softmax. If not provided, it is set to 1.0 / sqrt(head_dim).
rope_scale (Optional[float]) – The scale used in RoPE interpolation. If not provided, it is set to 1.0.
rope_theta (Optional[float]) – The theta used in RoPE. If not provided, it is set to 1e4.
q_data_type (Union[str, torch.dtype]) – The data type of the query tensor, defaults to torch.float16.
kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, it is set to q_data_type.
non_blocking (bool) – Whether to copy the input tensors to the device asynchronously, defaults to False. If True, the user should synchronize before calling run() or replaying a CUDA graph.
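To make the flattened custom_mask layout and the causal and window_left options concrete, here is a plain-Python sketch that builds, for each request, a row-major qo_len x kv_len boolean block and concatenates the blocks, mirroring the torch.tril(..., diagonal=kv_len - qo_len) construction in the class example. The window rule used here (a query at absolute position p sees keys p - window_left through p) is one reading of "left (inclusive) window size" and is an assumption; the helper name flattened_mask is hypothetical.

```python
from typing import List


def flattened_mask(qo_len: List[int], kv_len: List[int], window_left: int = -1) -> List[bool]:
    """Hypothetical helper: concatenated row-major causal masks for a batch."""
    flat = []
    for q_n, kv_n in zip(qo_len, kv_len):
        offset = kv_n - q_n  # queries are the last q_n tokens of the sequence
        for i in range(q_n):
            pos = i + offset  # absolute position of query row i
            for j in range(kv_n):
                visible = j <= pos  # same as tril(..., diagonal=offset)
                if window_left >= 0:
                    visible = visible and j >= pos - window_left  # assumed window rule
                flat.append(visible)
    return flat


mask = flattened_mask(qo_len=[2, 1], kv_len=[3, 2])
print(len(mask))  # 8, i.e. 2 * 3 + 1 * 2
print(mask)       # [True, True, False, True, True, True, True, True]
```

Passing causal=True to plan() should be equivalent to supplying this mask with window_left=-1, which is exactly what the class example asserts with torch.allclose.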
Note
The plan() method should be called before any run() or run_return_lse() calls; the auxiliary data structures are created during this call and cached for multiple kernel runs.
The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function uses grouped query attention.
The plan() method cannot be captured in a CUDA graph or used inside torch.compile.
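The requirement that num_qo_heads be a multiple of num_kv_heads comes from grouped query attention, where consecutive groups of query heads share one key/value head. This sketch assumes the conventional mapping kv_head = qo_head // (num_qo_heads // num_kv_heads); it is illustrative, and the helper name gqa_kv_head is hypothetical rather than taken from FlashInfer.

```python
def gqa_kv_head(qo_head: int, num_qo_heads: int, num_kv_heads: int) -> int:
    """Hypothetical helper: index of the KV head read by a given query head."""
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    group_size = num_qo_heads // num_kv_heads  # query heads per KV head
    return qo_head // group_size


# Configuration from the class example: 64 query heads, 16 KV heads, groups of 4.
print([gqa_kv_head(h, 64, 16) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

When num_qo_heads == num_kv_heads the group size is 1 and this reduces to ordinary multi-head attention.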
- reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor) → None
Reset the workspace buffer.
- Parameters:
float_workspace_buffer (torch.Tensor) – The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors.
int_workspace_buffer (torch.Tensor) – The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors.
- run(q: torch.Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], k_scale: float | None = None, v_scale: float | None = None, return_lse: Literal[False] = False) → torch.Tensor
- run(q: torch.Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], k_scale: float | None = None, v_scale: float | None = None, return_lse: Literal[True] = True) → Tuple[torch.Tensor, torch.Tensor]
Compute batch prefill/append attention between the query and the paged kv-cache.
- Parameters:
q (torch.Tensor) – The query tensor, shape: [qo_indptr[-1], num_qo_heads, head_dim].
paged_kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – The paged KV-Cache, stored either as a tuple of tensors or as a single tensor:
a tuple (k_cache, v_cache) of 4-D tensors, each with shape [max_num_pages, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, num_kv_heads, page_size, head_dim] if kv_layout is HND;
a single 5-D tensor with shape [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND, where paged_kv_cache[:, 0] is the key-cache and paged_kv_cache[:, 1] is the value-cache.
k_scale (Optional[float]) – The calibration scale of key for fp8 input. If not provided, it is set to 1.0.
v_scale (Optional[float]) – The calibration scale of value for fp8 input. If not provided, it is set to 1.0.
return_lse (bool) – Whether to return the logsumexp of the attention output.
- Returns:
If return_lse is False, the attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim]. If return_lse is True, a tuple of two tensors:
The attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim].
The logsumexp of the attention output, shape: [qo_indptr[-1], num_qo_heads].
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
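The two kv_layout options differ only in whether the page_size axis or the num_kv_heads axis comes first within a page. The sketch computes the flat element offset of (page, slot, head, dim) in a contiguous 4-D k_cache or v_cache for each layout, using the shapes listed above; it is plain index arithmetic, and the helper name cache_offset is hypothetical.

```python
def cache_offset(page: int, slot: int, head: int, dim: int,
                 page_size: int, num_kv_heads: int, head_dim: int, kv_layout: str) -> int:
    """Hypothetical helper: flat offset into a contiguous 4-D k_cache or v_cache."""
    if kv_layout == "NHD":  # [max_num_pages, page_size, num_kv_heads, head_dim]
        return ((page * page_size + slot) * num_kv_heads + head) * head_dim + dim
    if kv_layout == "HND":  # [max_num_pages, num_kv_heads, page_size, head_dim]
        return ((page * num_kv_heads + head) * page_size + slot) * head_dim + dim
    raise ValueError(f"unknown kv_layout: {kv_layout}")


args = dict(page_size=16, num_kv_heads=16, head_dim=128)
# Moving to the next slot of the same head: a whole token row in NHD, one head_dim in HND.
print(cache_offset(0, 1, 0, 0, kv_layout="NHD", **args))  # 2048
print(cache_offset(0, 1, 0, 0, kv_layout="HND", **args))  # 128
```

In HND the tokens of one head within a page are contiguous, while in NHD all heads of one token are contiguous; page boundaries land at the same offset in both layouts.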
- class flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: torch.Tensor | None = None, kv_indptr_buf: torch.Tensor | None = None, custom_mask_buf: torch.Tensor | None = None, qk_indptr_buf: torch.Tensor | None = None)
Wrapper class for prefill/append attention with a ragged (tensor) kv-cache for a batch of requests.
See the tutorial for the ragged kv-cache layout.
Example
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 16
>>> head_dim = 128
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> nnz_kv = 100
>>> nnz_qo = 100
>>> qo_indptr = torch.tensor(
...     [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
... )
>>> kv_indptr = qo_indptr.clone()
>>> q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
>>> k_at_layer = torch.randn(num_layers, nnz_kv, num_kv_heads, head_dim).half().to("cuda:0")
>>> v_at_layer = torch.randn(num_layers, nnz_kv, num_kv_heads, head_dim).half().to("cuda:0")
>>> # create auxiliary data structures for batch prefill attention
>>> prefill_wrapper.plan(
...     qo_indptr,
...     kv_indptr,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     causal=True,
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     k = k_at_layer[i]
...     v = v_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o = prefill_wrapper.run(q, k, v)
...     outputs.append(o)
...
>>> outputs[0].shape
torch.Size([100, 64, 128])
>>>
>>> # below is another example of creating custom mask for batch prefill attention
>>> mask_arr = []
>>> qo_len = (qo_indptr[1:] - qo_indptr[:-1]).cpu().tolist()
>>> kv_len = (kv_indptr[1:] - kv_indptr[:-1]).cpu().tolist()
>>> for i in range(batch_size):
...     mask_i = torch.tril(
...         torch.full((qo_len[i], kv_len[i]), True, device="cuda:0"),
...         diagonal=(kv_len[i] - qo_len[i]),
...     )
...     mask_arr.append(mask_i.flatten())
...
>>> mask = torch.cat(mask_arr, dim=0)
>>> prefill_wrapper.plan(
...     qo_indptr,
...     kv_indptr,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     custom_mask=mask,
... )
>>> outputs_custom_mask = []
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     k = k_at_layer[i]
...     v = v_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o_custom = prefill_wrapper.run(q, k, v)
...     assert torch.allclose(o_custom, outputs[i], rtol=1e-3, atol=1e-3)
...     outputs_custom_mask.append(o_custom)
...
>>> outputs_custom_mask[0].shape
torch.Size([100, 64, 128])
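In the ragged layout, the queries (and the keys/values) of all requests are concatenated along the first dimension, and qo_indptr / kv_indptr are prefix sums of the per-request lengths. The sketch below rebuilds the example's qo_indptr from its lengths and slices out one request's rows; it uses plain Python lists in place of tensors, and the helper names lengths_to_indptr and request_rows are hypothetical.

```python
from itertools import accumulate
from typing import List


def lengths_to_indptr(lengths: List[int]) -> List[int]:
    """Hypothetical helper: exclusive prefix sum, so indptr has len(lengths) + 1 entries."""
    return [0] + list(accumulate(lengths))


def request_rows(rows: list, indptr: List[int], i: int) -> list:
    """Rows of the ragged tensor that belong to request i."""
    return rows[indptr[i]:indptr[i + 1]]


qo_indptr = lengths_to_indptr([33, 11, 11, 11, 11, 11, 12])
print(qo_indptr)  # [0, 33, 44, 55, 66, 77, 88, 100]
tokens = list(range(qo_indptr[-1]))  # stand-in for the 100 query rows
print(len(request_rows(tokens, qo_indptr, 0)))  # 33
```

With torch tensors the same slice is q[qo_indptr[i]:qo_indptr[i + 1]], and the last indptr entry equals the first dimension of q (or of k and v for kv_indptr).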
Note
To accelerate computation, FlashInfer’s batch prefill/append attention operators create auxiliary data structures that can be reused across multiple prefill/append attention calls (e.g. different Transformer layers). This wrapper class manages the lifecycle of these data structures.
- __init__(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: torch.Tensor | None = None, kv_indptr_buf: torch.Tensor | None = None, custom_mask_buf: torch.Tensor | None = None, qk_indptr_buf: torch.Tensor | None = None) → None
Constructor of BatchPrefillWithRaggedKVCacheWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user-reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB, and the workspace buffer should be on the same device as the input tensors.
kv_layout (str) – The layout of the input k/v tensors, either NHD or HND.
use_cuda_graph (bool) – Whether to enable CUDA graph capture for the prefill kernels. If enabled, the auxiliary data structures are stored in the provided buffers.
qo_indptr_buf (Optional[torch.Tensor]) – The user-reserved GPU buffer to store the qo_indptr array; its size should be [batch_size + 1]. Only effective when use_cuda_graph is True.
kv_indptr_buf (Optional[torch.Tensor]) – The user-reserved GPU buffer to store the kv_indptr array; its size should be [batch_size + 1]. Only effective when use_cuda_graph is True.
custom_mask_buf (Optional[torch.Tensor]) – The user-reserved GPU buffer to store the custom mask tensor; it should be large enough to hold the largest packed custom mask tensor used during the lifetime of the wrapper. Only effective when use_cuda_graph is True and a custom mask is used in the attention computation.
qk_indptr_buf (Optional[torch.Tensor]) – The user-reserved GPU buffer to store the qk_indptr array; its size should be [batch_size + 1]. Only effective when use_cuda_graph is True and a custom mask is used in the attention computation.
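Because each request contributes a qo_len[i] x kv_len[i] block to the flattened custom_mask, an indptr over that mask with batch_size + 1 entries locates each request's block. It is an assumption here that this is the information a qk_indptr-style array carries; the sketch computes such offsets and recovers one request's block as rows, and the helper names mask_indptr and request_mask are hypothetical.

```python
from itertools import accumulate
from typing import List


def mask_indptr(qo_len: List[int], kv_len: List[int]) -> List[int]:
    """Hypothetical helper: start offset of each request's block in the flat mask."""
    return [0] + list(accumulate(q * kv for q, kv in zip(qo_len, kv_len)))


def request_mask(flat_mask: List[bool], indptr: List[int], kv_len: List[int], i: int) -> List[List[bool]]:
    """Recover request i's mask as a list of rows (row-major layout)."""
    block = flat_mask[indptr[i]:indptr[i + 1]]
    n = kv_len[i]
    return [block[r:r + n] for r in range(0, len(block), n)]


qo_len, kv_len = [2, 1], [3, 2]
flat = [True, True, False, True, True, True, True, True]  # causal masks, row-major
indptr = mask_indptr(qo_len, kv_len)
print(indptr)                                 # [0, 6, 8]
print(request_mask(flat, indptr, kv_len, 0))  # [[True, True, False], [True, True, True]]
```

The last entry of such an indptr equals the documented custom_mask length, sum(q_len[i] * k_len[i]).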
- plan(qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, custom_mask: torch.Tensor | None = None, packed_custom_mask: torch.Tensor | None = None, causal: bool = False, pos_encoding_mode: str = 'NONE', allow_fp16_qk_reduction: bool = False, window_left: int = -1, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, q_data_type: str = 'float16', kv_data_type: str | None = None) → None
Plan batch prefill/append attention on the ragged KV-Cache for the given problem specification.
- Parameters:
qo_indptr (torch.Tensor) – The indptr of the query/output tensor, shape: [batch_size + 1].
kv_indptr (torch.Tensor) – The indptr of the key/value tensor, shape: [batch_size + 1].
num_qo_heads (int) – The number of query/output heads.
num_kv_heads (int) – The number of key/value heads.
head_dim (int) – The dimension of the heads.
custom_mask (Optional[torch.Tensor]) – The flattened boolean mask tensor, shape: [sum(q_len[i] * k_len[i] for i in range(batch_size))]. Each element must be either True or False, where False means the corresponding element in the attention matrix is masked out. Please refer to the mask layout for more details about the flattened layout of the mask tensor. When custom_mask is provided and packed_custom_mask is not, the function packs the custom mask tensor into a 1D packed mask tensor, which introduces additional overhead.
packed_custom_mask (Optional[torch.Tensor]) – The 1D packed uint8 mask tensor. If provided, custom_mask is ignored. The packed mask tensor is generated by flashinfer.quantization.packbits(), should be on the same device as the input tensors, and is applied to the attention matrix after scaling and before softmax.
causal (bool) – Whether to apply a causal mask to the attention matrix. This argument is ignored if a custom mask is provided in plan().
pos_encoding_mode (str) – The position encoding applied inside attention kernels, one of NONE, ROPE_LLAMA (LLaMA-style rotary embedding), or ALIBI. Default is NONE.
allow_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).
window_left (int) – The left (inclusive) window size for the attention window. When set to -1, the window covers the full length of the sequence. Defaults to -1.
logits_soft_cap (Optional[float]) – The attention logits soft-capping value (used in Gemini, Grok, Gemma-2, etc.). If not provided, it is set to 0. If greater than 0, the logits are capped according to the formula \(\texttt{logits\_soft\_cap} \times \tanh(x / \texttt{logits\_soft\_cap})\), where \(x\) is the input logits.
sm_scale (Optional[float]) – The scale used in softmax. If not provided, it is set to 1.0 / sqrt(head_dim).
rope_scale (Optional[float]) – The scale used in RoPE interpolation. If not provided, it is set to 1.0.
rope_theta (Optional[float]) – The theta used in RoPE. If not provided, it is set to 1e4.
q_data_type (Union[str, torch.dtype]) – The data type of the query tensor, defaults to torch.float16.
kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, it is set to q_data_type.
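packed_custom_mask stores one bit per mask element in a uint8 tensor produced by flashinfer.quantization.packbits(). The pure-Python sketch below illustrates the idea; the bit order within each byte is an assumption here (element k of each group of 8 goes to bit k, as with numpy.packbits(..., bitorder="little")), so check which order your FlashInfer version uses before relying on it. Both helper names are hypothetical.

```python
from typing import List


def packbits_little(bits: List[bool]) -> List[int]:
    """Sketch: pack booleans into bytes, element k of each group of 8 -> bit k."""
    out = []
    for start in range(0, len(bits), 8):
        byte = 0
        for k, b in enumerate(bits[start:start + 8]):
            if b:
                byte |= 1 << k
        out.append(byte)  # a short final group is implicitly zero-padded
    return out


def unpackbits_little(packed: List[int], n: int) -> List[bool]:
    """Inverse of packbits_little, truncated to the original n elements."""
    return [bool((packed[i // 8] >> (i % 8)) & 1) for i in range(n)]


mask = [True, True, False, True, True, True, True, True, True, False]
packed = packbits_little(mask)
print(packed)  # [251, 1]
```

Packing shrinks the mask eightfold, which is why passing an already packed mask avoids the extra packing step that plan() otherwise performs.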
Note
The plan() method should be called before any run() or run_return_lse() calls; the auxiliary data structures are created during this plan call and cached for multiple kernel runs.
The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function uses grouped query attention.
The plan() method cannot be captured in a CUDA graph or used inside torch.compile.
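The logits preprocessing described by sm_scale and logits_soft_cap in plan() can be written as a small scalar function. This sketch assumes the soft cap is applied to the already scaled logit x = sm_scale * (q · k), and treats a cap of 0 (the documented default) as disabled; attention_logit is a hypothetical helper, not FlashInfer's kernel code.

```python
import math
from typing import List, Optional


def attention_logit(q: List[float], k: List[float],
                    sm_scale: Optional[float] = None,
                    logits_soft_cap: Optional[float] = None) -> float:
    """Hypothetical helper: scaled and optionally soft-capped attention logit."""
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(len(q))  # documented default: 1 / sqrt(head_dim)
    x = sm_scale * sum(a * b for a, b in zip(q, k))
    cap = logits_soft_cap or 0.0  # None -> 0, i.e. no capping
    if cap > 0:
        x = cap * math.tanh(x / cap)  # smoothly bounded to (-cap, cap)
    return x


q = [1.0] * 128  # head_dim = 128
print(attention_logit(q, q))                       # ≈ 11.31, i.e. sqrt(128)
print(attention_logit(q, q, logits_soft_cap=5.0))  # ≈ 4.89, below the cap of 5.0
```

Because tanh(y) ≈ y for small y, the cap leaves small logits almost unchanged and only compresses large ones.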
- reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer) → None
Reset the workspace buffer.
- Parameters:
float_workspace_buffer (torch.Tensor) – The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors.
int_workspace_buffer (torch.Tensor) – The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors.
- run(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return_lse: Literal[False] = False) → torch.Tensor
- run(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return_lse: Literal[True] = True) → Tuple[torch.Tensor, torch.Tensor]
Compute batch prefill/append attention between the query and a kv-cache stored as ragged tensors.
- Parameters:
q (torch.Tensor) – The query tensor, shape: [qo_indptr[-1], num_qo_heads, head_dim].
k (torch.Tensor) – The key tensor, shape: [kv_indptr[-1], num_kv_heads, head_dim].
v (torch.Tensor) – The value tensor, shape: [kv_indptr[-1], num_kv_heads, head_dim].
return_lse (bool) – Whether to return the logsumexp of the attention output.
- Returns:
If return_lse is False, the attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim]. If return_lse is True, a tuple of two tensors:
The attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim].
The logsumexp of the attention output, shape: [qo_indptr[-1], num_qo_heads].
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
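When return_lse is True, the returned logsumexp makes partial results combinable, for example when the same queries attend to two disjoint chunks of KV. The sketch below shows the standard merge rule for one query row and one head in plain Python and checks it against attention over the concatenated keys. It uses natural logarithms; the base used by FlashInfer's returned lse is not stated here, so convert if needed. The helper names attend and merge are hypothetical.

```python
import math
from typing import List, Tuple


def attend(logits: List[float], values: List[List[float]]) -> Tuple[List[float], float]:
    """Softmax-weighted average of `values` and the logsumexp of `logits` (natural log)."""
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    total = sum(weights)
    out = [sum(w * v[c] for w, v in zip(weights, values)) / total
           for c in range(len(values[0]))]
    return out, m + math.log(total)


def merge(o_a: List[float], lse_a: float,
          o_b: List[float], lse_b: float) -> Tuple[List[float], float]:
    """Combine two partial attention outputs computed over disjoint KV chunks."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    w_a, w_b = math.exp(lse_a - lse), math.exp(lse_b - lse)  # the two weights sum to 1
    return [w_a * a + w_b * b for a, b in zip(o_a, o_b)], lse


logits = [0.5, -1.0, 2.0, 0.0]  # scaled q·k for one query row and one head
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, -1.0]]
full, lse_full = attend(logits, values)
o_a, lse_a = attend(logits[:2], values[:2])  # first KV chunk
o_b, lse_b = attend(logits[2:], values[2:])  # second KV chunk
merged, lse_merged = merge(o_a, lse_a, o_b, lse_b)
print(all(math.isclose(x, y) for x, y in zip(full, merged)))  # True
```

Each partial output is weighted by its share of the total softmax mass, exp(lse_part - lse_total), which is why the merge needs the logsumexp and not just the normalized outputs.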