flashinfer.prefill

Attention kernels for prefill and append attention, in both single-request and batch serving settings.

Single Request Prefill/Append Attention

single_prefill_with_kv_cache()

Prefill/Append attention with KV cache for a single request; returns the attention output.

single_prefill_with_kv_cache_return_lse(q, k, v)

Prefill/Append attention with KV cache for a single request; returns the attention output together with the logsumexp (LSE) of the attention logits.
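To make the semantics of the two single-request functions concrete, here is a naive pure-Python reference for one head of one request: softmax attention with the default scale 1/sqrt(head_dim), plus the per-row logsumexp of the scaled logits. It is a sketch for checking semantics on tiny inputs, not FlashInfer's kernel. The function name and list-of-lists "tensors" are illustrative; the causal alignment follows the bottom-right convention used by the custom-mask examples on this page, and the natural-log LSE is an assumption (the library may report its LSE in a different base).

```python
import math


def naive_attention_with_lse(q, k, v, causal=False, sm_scale=None):
    """Naive attention for one head of one request (illustrative, not FlashInfer's kernel).

    q is a list of qo_len query vectors; k and v are lists of kv_len vectors.
    Returns (out, lse): out has qo_len rows of head_dim values, lse has qo_len floats.
    """
    qo_len, kv_len = len(q), len(k)
    assert not causal or qo_len <= kv_len, "causal sketch assumes qo_len <= kv_len"
    head_dim = len(q[0])
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)  # documented default scale
    out, lse = [], []
    for i, qi in enumerate(q):
        # Bottom-right-aligned causal mask: query i sees keys 0 .. kv_len - qo_len + i.
        last = kv_len - qo_len + i if causal else kv_len - 1
        logits = [sm_scale * sum(a * b for a, b in zip(qi, k[j])) for j in range(last + 1)]
        m = max(logits)  # subtract the max for numerical stability
        denom = sum(math.exp(x - m) for x in logits)
        lse.append(m + math.log(denom))  # natural-log logsumexp of the scaled logits
        weights = [math.exp(x - m) / denom for x in logits]
        out.append([sum(w * v[j][d] for j, w in enumerate(weights)) for d in range(head_dim)])
    return out, lse


# Tiny usage example: one query, two keys, unit scale.
out, lse = naive_attention_with_lse([[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]],
                                    [[1.0, 2.0], [3.0, 4.0]], sm_scale=1.0)
```

Comparing a real kernel's output against such a reference (after moving tensors to CPU lists) is a cheap way to confirm masking and scaling choices on small shapes.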

Batch Prefill/Append Attention

class flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: torch.Tensor | None = None, paged_kv_indptr_buf: torch.Tensor | None = None, paged_kv_indices_buf: torch.Tensor | None = None, paged_kv_last_page_len_buf: torch.Tensor | None = None, custom_mask_buf: torch.Tensor | None = None, qk_indptr_buf: torch.Tensor | None = None, backend: str = 'auto', jit_module: Any = None)

Wrapper class for prefill/append attention with a paged kv-cache for a batch of requests.

Check our tutorial for page table layout.
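The page table is described by three CSR-style arrays: paged_kv_indptr, paged_kv_indices and paged_kv_last_page_len. The sketch below builds them from per-request KV lengths and shows how to recover a request's KV length (the same formula as the custom-mask example below) and the page/slot holding a given entry. It assumes a hypothetical toy allocator that hands out consecutive page ids; real serving engines manage page allocation themselves, and all helper names here are illustrative.

```python
def build_paged_kv_metadata(kv_lens, page_size, first_page=0):
    """Toy allocator: give each request consecutive page ids, return CSR page-table arrays.

    Returns (paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len) as Python lists.
    """
    indptr, indices, last_page_len = [0], [], []
    next_page = first_page
    for n in kv_lens:
        assert n >= 1, "every request needs at least one KV entry"
        num_pages = (n + page_size - 1) // page_size  # ceil(n / page_size)
        indices.extend(range(next_page, next_page + num_pages))
        next_page += num_pages
        indptr.append(indptr[-1] + num_pages)
        last_page_len.append(n - (num_pages - 1) * page_size)  # 1 <= value <= page_size
    return indptr, indices, last_page_len


def kv_len_of(i, indptr, last_page_len, page_size):
    """KV length of request i: full pages plus the partially filled last page."""
    return page_size * (indptr[i + 1] - indptr[i] - 1) + last_page_len[i]


def locate_kv_entry(i, pos, indptr, indices, page_size):
    """(page_id, slot_in_page) holding the pos-th KV entry of request i."""
    page_id = indices[indptr[i] + pos // page_size]
    return page_id, pos % page_size


indptr, indices, last = build_paged_kv_metadata([5, 16, 17], page_size=16)
# indptr == [0, 1, 2, 4], indices == [0, 1, 2, 3], last == [5, 16, 1]
```

Converted to int32 device tensors, these three lists are what plan() expects for paged_kv_indptr, paged_kv_indices and paged_kv_last_page_len.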

Example

>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 16
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> nnz_qo = 100
>>> qo_indptr = torch.tensor(
...     [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
... )
>>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> paged_kv_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= paged_kv_last_page_len <= page_size
>>> paged_kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
>>> kv_cache_at_layer = torch.randn(
...     num_layers, max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
... )
>>> # create auxiliary data structures for batch prefill attention
>>> prefill_wrapper.plan(
...     qo_indptr,
...     paged_kv_indptr,
...     paged_kv_indices,
...     paged_kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     causal=True,
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     kv_cache = kv_cache_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o = prefill_wrapper.run(q, kv_cache)
...     outputs.append(o)
...
>>> outputs[0].shape
torch.Size([100, 64, 128])
>>>
>>> # below is another example of creating custom mask for batch prefill attention
>>> mask_arr = []
>>> qo_len = (qo_indptr[1:] - qo_indptr[:-1]).cpu().tolist()
>>> kv_len = (page_size * (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) + paged_kv_last_page_len).cpu().tolist()
>>> for i in range(batch_size):
...     mask_i = torch.tril(
...         torch.full((qo_len[i], kv_len[i]), True, device="cuda:0"),
...         diagonal=(kv_len[i] - qo_len[i]),
...     )
...     mask_arr.append(mask_i.flatten())
...
>>> mask = torch.cat(mask_arr, dim=0)
>>> prefill_wrapper.plan(
...     qo_indptr,
...     paged_kv_indptr,
...     paged_kv_indices,
...     paged_kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     custom_mask=mask,
... )
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     kv_cache = kv_cache_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o_custom = prefill_wrapper.run(q, kv_cache)
...     assert torch.allclose(o_custom, outputs[i], rtol=1e-3, atol=1e-3)
...

Note

To accelerate computation, FlashInfer’s batch prefill/append attention operators create auxiliary data structures that can be reused across multiple prefill/append attention calls (e.g. different Transformer layers). This wrapper class manages the lifecycle of these data structures.

__init__(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: torch.Tensor | None = None, paged_kv_indptr_buf: torch.Tensor | None = None, paged_kv_indices_buf: torch.Tensor | None = None, paged_kv_last_page_len_buf: torch.Tensor | None = None, custom_mask_buf: torch.Tensor | None = None, qk_indptr_buf: torch.Tensor | None = None, backend: str = 'auto', jit_module: Any = None) → None

Constructor of BatchPrefillWithPagedKVCacheWrapper.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The user-reserved workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB; the workspace buffer should be on the same device as the input tensors.

  • kv_layout (str) – The layout of the input k/v tensors, could be either NHD or HND.

  • use_cuda_graph (bool) – Whether to enable CUDA graph capture for the prefill kernels. If enabled, the auxiliary data structures are stored in the provided buffers, and batch_size cannot change during the lifecycle of this wrapper.

  • qo_indptr_buf (Optional[torch.Tensor]) – The user reserved buffer to store the qo_indptr array, the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.

  • paged_kv_indptr_buf (Optional[torch.Tensor]) – The user reserved buffer to store the paged_kv_indptr array, the size of this buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.

  • paged_kv_indices_buf (Optional[torch.Tensor]) – The user reserved buffer to store the paged_kv_indices array, should be large enough to store the maximum possible size of the paged_kv_indices array during the lifetime of the wrapper. This argument is only effective when use_cuda_graph is True.

  • paged_kv_last_page_len_buf (Optional[torch.Tensor]) – The user reserved buffer to store the paged_kv_last_page_len array, the size of the buffer should be [batch_size]. This argument is only effective when use_cuda_graph is True.

  • custom_mask_buf (Optional[torch.Tensor]) – The user reserved buffer to store the custom mask tensor, should be large enough to store the maximum possible size of the packed custom mask tensor during the lifetime of the wrapper. This argument is only effective when use_cuda_graph is set to True and the custom mask will be used in attention computation.

  • qk_indptr_buf (Optional[torch.Tensor]) – The user reserved buffer to store the qk_indptr array, the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True and the custom mask will be used in attention computation.

  • backend (str) – The implementation backend, one of auto, fa2 or fa3. Defaults to auto, in which case the backend is chosen automatically based on the device architecture and kernel availability.
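When use_cuda_graph is True, the buffers above must be allocated once, at their worst-case sizes, before capture. The hypothetical helper below turns the size rules stated in the parameter descriptions into element counts. The custom-mask bound is an assumption: it budgets one bit per mask element plus one spare byte per request, in case each request's packed mask is rounded up to whole bytes separately.

```python
def cuda_graph_buffer_sizes(max_batch_size, max_num_pages, max_mask_elements=None):
    """Element counts for the user-provided buffers of a CUDA-graph-enabled paged wrapper.

    max_mask_elements is an upper bound on sum(q_len[i] * k_len[i]) over the batch;
    pass it only if a custom mask will be used.
    """
    sizes = {
        "qo_indptr_buf": max_batch_size + 1,
        "paged_kv_indptr_buf": max_batch_size + 1,
        "paged_kv_indices_buf": max_num_pages,  # worst case: every page in use
        "paged_kv_last_page_len_buf": max_batch_size,
    }
    if max_mask_elements is not None:
        sizes["qk_indptr_buf"] = max_batch_size + 1
        # One bit per mask element, packed into bytes, plus one spare byte per request
        # as a conservative bound for per-request rounding (an assumption).
        sizes["custom_mask_buf"] = (max_mask_elements + 7) // 8 + max_batch_size
    return sizes
```

Each count could then be allocated once, for example with torch.empty(n, dtype=torch.int32, device="cuda:0") for the index buffers (int32 matches the example above) and a uint8 tensor for the packed mask, and the same buffers reused across plan() calls.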

plan(qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor, paged_kv_indices: torch.Tensor, paged_kv_last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, custom_mask: torch.Tensor | None = None, packed_custom_mask: torch.Tensor | None = None, causal: bool = False, pos_encoding_mode: str = 'NONE', allow_fp16_qk_reduction: bool = False, sm_scale: float | None = None, window_left: int = -1, logits_soft_cap: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, q_data_type: str | torch.dtype = 'float16', kv_data_type: str | torch.dtype | None = None, non_blocking: bool = False) → None

Plan batch prefill/append attention on Paged KV-Cache for given problem specification.

Parameters:
  • qo_indptr (torch.Tensor) – The indptr of the query/output tensor, shape: [batch_size + 1].

  • paged_kv_indptr (torch.Tensor) – The indptr of the paged kv-cache, shape: [batch_size + 1].

  • paged_kv_indices (torch.Tensor) – The page indices of the paged kv-cache, shape: [paged_kv_indptr[-1]].

  • paged_kv_last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv-cache, shape: [batch_size].

  • num_qo_heads (int) – The number of query/output heads.

  • num_kv_heads (int) – The number of key/value heads.

  • head_dim (int) – The dimension of the heads.

  • page_size (int) – The size of each page in the paged kv-cache.

  • custom_mask (Optional[torch.Tensor]) –

    The flattened boolean mask tensor, shape: [sum(q_len[i] * k_len[i] for i in range(batch_size))]. Each element should be either True or False, where False means the corresponding element in the attention matrix is masked out.

    Please refer to the mask layout for more details about flattened layout of mask tensor.

    When custom_mask is provided, and packed_custom_mask is not, the function will pack the custom mask tensor into a 1D packed mask tensor, which introduces additional overhead.

  • packed_custom_mask (Optional[torch.Tensor]) – The 1D packed uint8 mask tensor, if provided, the custom_mask will be ignored. The packed mask tensor is generated by flashinfer.quantization.packbits().

  • causal (bool) – Whether to apply causal mask to the attention matrix. This is only effective when custom_mask is not provided in plan().

  • pos_encoding_mode (str) – The position encoding applied inside attention kernels, one of NONE, ROPE_LLAMA (LLaMA-style rotary embedding) or ALIBI. Default is NONE.

  • allow_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).

  • window_left (int) – The left (inclusive) size of the attention window; when set to -1, the window covers the full length of the sequence. Defaults to -1.

  • logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.); if not provided, it is set to 0. If greater than 0, the logits are capped according to the formula \(\texttt{logits\_soft\_cap} \times \tanh(x / \texttt{logits\_soft\_cap})\), where \(x\) is the input logits.

  • sm_scale (Optional[float]) – The scale used in softmax, if not provided, will be set to 1.0 / sqrt(head_dim).

  • rope_scale (Optional[float]) – The scale used in RoPE interpolation, if not provided, will be set to 1.0.

  • rope_theta (Optional[float]) – The theta used in RoPE, if not provided, will be set to 1e4.

  • q_data_type (Union[str, torch.dtype]) – The data type of the query tensor, defaults to torch.float16.

  • kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, will be set to q_data_type.

  • non_blocking (bool) – Whether to copy the input tensors to the device asynchronously, defaults to False. If True, the user should synchronize before calling run() or replaying a CUDA graph.
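The flattened custom_mask layout described above concatenates each request's qo_len x kv_len block in row-major order. The sketch below builds that layout for a bottom-right-aligned causal mask (the same rule as torch.tril(..., diagonal=kv_len - qo_len) in the class example) and packs bits the way numpy.packbits(..., bitorder="little") does. Whether flashinfer.quantization.packbits() uses this bit order, and whether it packs each request separately, are assumptions here; produce the real packed_custom_mask with the library function. Helper names are hypothetical.

```python
def flatten_causal_mask(qo_lens, kv_lens):
    """Flattened boolean mask: request by request, each qo_len x kv_len block row-major."""
    flat = []
    for qo_len, kv_len in zip(qo_lens, kv_lens):
        for i in range(qo_len):
            for j in range(kv_len):
                # Same rule as torch.tril(..., diagonal=kv_len - qo_len) in the class example.
                flat.append(j <= i + kv_len - qo_len)
    return flat


def packbits_little(bits):
    """Pack booleans into bytes, least-significant bit first (like numpy bitorder='little')."""
    packed = []
    for start in range(0, len(bits), 8):
        byte = 0
        for offset, bit in enumerate(bits[start:start + 8]):
            if bit:
                byte |= 1 << offset
        packed.append(byte)
    return packed


mask = flatten_causal_mask([2], [3])
# rows: [True, True, False] and [True, True, True]
```

The packed form stores eight mask elements per byte, which is why passing packed_custom_mask directly avoids the packing overhead mentioned above.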

Note

The plan() method should be called before any run() or run_return_lse() calls; auxiliary data structures are created during this call and cached for multiple kernel runs.

The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.

The plan() method cannot be used inside CUDA graph capture or in torch.compile.
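Under grouped query attention, several query heads share one key/value head. The sketch below assumes the common contiguous grouping, where consecutive query heads map to the same KV head; the helper name is hypothetical and this is an illustration of the head mapping, not FlashInfer's internal code.

```python
def kv_head_for_qo_head(qo_head, num_qo_heads, num_kv_heads):
    """Index of the KV head shared by query head qo_head under grouped-query attention."""
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    group_size = num_qo_heads // num_kv_heads  # query heads per KV head
    return qo_head // group_size


# With the example's 64 query heads and 16 KV heads, each KV head serves 4 query heads.
groups = {}
for h in range(64):
    groups.setdefault(kv_head_for_qo_head(h, 64, 16), []).append(h)
```

When num_qo_heads equals num_kv_heads the group size is 1 and this reduces to ordinary multi-head attention.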

reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor) → None

Reset the workspace buffer.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The new float workspace buffer; it should be on the same device as the input tensors.

  • int_workspace_buffer (torch.Tensor) – The new int workspace buffer; it should be on the same device as the input tensors.

run(q: torch.Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], *args, k_scale: float | None = None, v_scale: float | None = None, return_lse: Literal[False] = False) → torch.Tensor
run(q: torch.Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], *args, k_scale: float | None = None, v_scale: float | None = None, return_lse: Literal[True] = True) → Tuple[torch.Tensor, torch.Tensor]

Compute batch prefill/append attention between query and paged kv-cache.

Parameters:
  • q (torch.Tensor) – The query tensor, shape: [qo_indptr[-1], num_qo_heads, head_dim]

  • paged_kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) –

    The paged KV-Cache stored as a tuple of tensors or a single tensor:

    • a tuple (k_cache, v_cache) of 4-D tensors, each with shape: [max_num_pages, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, num_kv_heads, page_size, head_dim] if kv_layout is HND.

    • a single 5-D tensor with shape: [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND. Here paged_kv_cache[:, 0] is the key cache and paged_kv_cache[:, 1] is the value cache.

  • k_scale (Optional[float]) – The calibration scale of key for fp8 input, if not provided, will be set to 1.0.

  • v_scale (Optional[float]) – The calibration scale of value for fp8 input, if not provided, will be set to 1.0.

  • return_lse (bool) – Whether to also return the logsumexp (LSE) of the attention logits.

Returns:

If return_lse is False, the attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim]. If return_lse is True, a tuple of two tensors:

  • The attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim].

  • The logsumexp of the attention logits, shape: [qo_indptr[-1], num_qo_heads].

Return type:

Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
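For the single 5-D paged_kv_cache form described above, the only difference between NHD and HND is the order of the page_size and num_kv_heads axes. The hypothetical helpers below spell out the expected shape and the index of one head_dim vector for each layout; with a real tensor, paged_kv_cache[kv_vector_index(...)] would select a vector of head_dim elements.

```python
def paged_kv_cache_shape(max_num_pages, page_size, num_kv_heads, head_dim, kv_layout="NHD"):
    """Shape of the single 5-D paged KV tensor for the given layout."""
    if kv_layout == "NHD":
        return (max_num_pages, 2, page_size, num_kv_heads, head_dim)
    if kv_layout == "HND":
        return (max_num_pages, 2, num_kv_heads, page_size, head_dim)
    raise ValueError(f"unknown kv_layout: {kv_layout}")


def kv_vector_index(page_id, kv, slot, head, kv_layout="NHD"):
    """Index of one head_dim vector; kv=0 selects the key cache, kv=1 the value cache."""
    if kv_layout == "NHD":
        return (page_id, kv, slot, head)
    if kv_layout == "HND":
        return (page_id, kv, head, slot)
    raise ValueError(f"unknown kv_layout: {kv_layout}")
```

Combined with the page/slot lookup from the page table, this pins down exactly where a given token's key or value lives in the cache.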

class flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: torch.Tensor | None = None, kv_indptr_buf: torch.Tensor | None = None, custom_mask_buf: torch.Tensor | None = None, qk_indptr_buf: torch.Tensor | None = None, backend: str = 'auto', jit_module: Any = None)

Wrapper class for prefill/append attention with a ragged (tensor) kv-cache for a batch of requests.

Check our tutorial for ragged kv-cache layout.
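In the ragged layout, the rows of all requests are concatenated along the first dimension and an indptr array marks where each request starts and ends. The short sketch below (hypothetical helper names, plain Python lists standing in for tensors) builds an indptr from per-request lengths with an exclusive prefix sum and slices out one request's rows.

```python
from itertools import accumulate


def lens_to_indptr(lens):
    """Exclusive prefix sum: indptr[i]:indptr[i + 1] spans request i's rows."""
    return [0] + list(accumulate(lens))


def rows_of_request(ragged_rows, indptr, i):
    """Rows of request i in a ragged tensor stored as one concatenated sequence."""
    return ragged_rows[indptr[i]:indptr[i + 1]]


# Per-request lengths matching qo_indptr in the example below.
qo_indptr = lens_to_indptr([33, 11, 11, 11, 11, 11, 12])
```

The same slicing applies to k and v with kv_indptr; torch tensors support the identical start:stop indexing along their first dimension.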

Example

>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 16
>>> head_dim = 128
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> nnz_kv = 100
>>> nnz_qo = 100
>>> qo_indptr = torch.tensor(
...     [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
... )
>>> kv_indptr = qo_indptr.clone()
>>> q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
>>> k_at_layer = torch.randn(num_layers, nnz_kv, num_kv_heads, head_dim).half().to("cuda:0")
>>> v_at_layer = torch.randn(num_layers, nnz_kv, num_kv_heads, head_dim).half().to("cuda:0")
>>> # create auxiliary data structures for batch prefill attention
>>> prefill_wrapper.plan(
...     qo_indptr,
...     kv_indptr,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     causal=True,
... )
>>> outputs = []
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     k = k_at_layer[i]
...     v = v_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o = prefill_wrapper.run(q, k, v)
...     outputs.append(o)
...
>>> outputs[0].shape
torch.Size([100, 64, 128])
>>>
>>> # below is another example of creating custom mask for batch prefill attention
>>> mask_arr = []
>>> qo_len = (qo_indptr[1:] - qo_indptr[:-1]).cpu().tolist()
>>> kv_len = (kv_indptr[1:] - kv_indptr[:-1]).cpu().tolist()
>>> for i in range(batch_size):
...     mask_i = torch.tril(
...         torch.full((qo_len[i], kv_len[i]), True, device="cuda:0"),
...         diagonal=(kv_len[i] - qo_len[i]),
...     )
...     mask_arr.append(mask_i.flatten())
...
>>> mask = torch.cat(mask_arr, dim=0)
>>> prefill_wrapper.plan(
...     qo_indptr,
...     kv_indptr,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     custom_mask=mask
... )
>>> outputs_custom_mask = []
>>> for i in range(num_layers):
...     q = q_at_layer[i]
...     k = k_at_layer[i]
...     v = v_at_layer[i]
...     # compute batch prefill attention, reuse auxiliary data structures
...     o_custom = prefill_wrapper.run(q, k, v)
...     outputs_custom_mask.append(o_custom)
...     assert torch.allclose(o_custom, outputs[i], rtol=1e-3, atol=1e-3)
...
>>> outputs_custom_mask[0].shape
torch.Size([100, 64, 128])

Note

To accelerate computation, FlashInfer’s batch prefill/append attention operators create auxiliary data structures that can be reused across multiple prefill/append attention calls (e.g. different Transformer layers). This wrapper class manages the lifecycle of these data structures.

__init__(float_workspace_buffer: torch.Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, qo_indptr_buf: torch.Tensor | None = None, kv_indptr_buf: torch.Tensor | None = None, custom_mask_buf: torch.Tensor | None = None, qk_indptr_buf: torch.Tensor | None = None, backend: str = 'auto', jit_module: Any = None) → None

Constructor of BatchPrefillWithRaggedKVCacheWrapper.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The user-reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB; the workspace buffer should be on the same device as the input tensors.

  • kv_layout (str) – The layout of the input k/v tensors, could be either NHD or HND.

  • use_cuda_graph (bool) – Whether to enable CUDA graph capture for the prefill kernels. If enabled, the auxiliary data structures are stored in the provided buffers.

  • qo_indptr_buf (Optional[torch.Tensor]) – The user reserved GPU buffer to store the qo_indptr array, the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.

  • kv_indptr_buf (Optional[torch.Tensor]) – The user reserved GPU buffer to store the kv_indptr array, the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True.

  • custom_mask_buf (Optional[torch.Tensor]) – The user reserved GPU buffer to store the custom mask tensor, should be large enough to store the maximum possible size of the packed custom mask tensor during the lifetime of the wrapper. This argument is only effective when use_cuda_graph is True and a custom mask will be used in attention computation.

  • qk_indptr_buf (Optional[torch.Tensor]) – The user reserved GPU buffer to store the qk_indptr array, the size of the buffer should be [batch_size + 1]. This argument is only effective when use_cuda_graph is True and a custom mask will be used in attention computation.

  • backend (str) – The implementation backend, one of auto, fa2 or fa3. Defaults to auto, in which case the backend is chosen automatically based on the device architecture and kernel availability.

plan(qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, custom_mask: torch.Tensor | None = None, packed_custom_mask: torch.Tensor | None = None, causal: bool = False, pos_encoding_mode: str = 'NONE', allow_fp16_qk_reduction: bool = False, window_left: int = -1, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, q_data_type: str = 'float16', kv_data_type: str | None = None) → None

Plan batch prefill/append attention on Ragged KV-Cache for given problem specification.

Parameters:
  • qo_indptr (torch.Tensor) – The indptr of the query/output tensor, shape: [batch_size + 1].

  • kv_indptr (torch.Tensor) – The indptr of the key/value tensor, shape: [batch_size + 1].

  • num_qo_heads (int) – The number of query/output heads.

  • num_kv_heads (int) – The number of key/value heads.

  • head_dim (int) – The dimension of the heads.

  • custom_mask (Optional[torch.Tensor]) –

    The flattened boolean mask tensor, shape: [sum(q_len[i] * k_len[i] for i in range(batch_size))]. Each element should be either True or False, where False means the corresponding element in the attention matrix is masked out.

    Please refer to the mask layout for more details about flattened layout of mask tensor.

    When custom_mask is provided, and packed_custom_mask is not, the function will pack the custom mask tensor into a 1D packed mask tensor, which introduces additional overhead.

  • packed_custom_mask (Optional[torch.Tensor]) –

    The 1D packed uint8 mask tensor, if provided, the custom_mask will be ignored. The packed mask tensor is generated by flashinfer.quantization.packbits().

    If provided, the mask is applied to the attention matrix after scaling and before softmax. The mask tensor should be on the same device as the input tensors.

  • causal (bool) – Whether to apply causal mask to the attention matrix. This argument is ignored if a custom mask is provided in plan().

  • pos_encoding_mode (str) – The position encoding applied inside attention kernels, one of NONE, ROPE_LLAMA (LLaMA-style rotary embedding) or ALIBI. Default is NONE.

  • allow_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).

  • window_left (int) – The left (inclusive) size of the attention window; when set to -1, the window covers the full length of the sequence. Defaults to -1.

  • logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.); if not provided, it is set to 0. If greater than 0, the logits are capped according to the formula \(\texttt{logits\_soft\_cap} \times \tanh(x / \texttt{logits\_soft\_cap})\), where \(x\) is the input logits.

  • sm_scale (Optional[float]) – The scale used in softmax, if not provided, will be set to 1.0 / sqrt(head_dim).

  • rope_scale (Optional[float]) – The scale used in RoPE interpolation, if not provided, will be set to 1.0.

  • rope_theta (Optional[float]) – The theta used in RoPE, if not provided, will be set to 1e4.

  • q_data_type (Union[str, torch.dtype]) – The data type of the query tensor, defaults to torch.float16.

  • kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, will be set to q_data_type.
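The sm_scale, logits_soft_cap, causal and window_left parameters together decide what each query sees and how its logits are shaped. The sketch below spells this out for a single (query, key) pair. Two points are assumptions rather than confirmed library behavior: that soft capping acts on the already-scaled logit (Gemma-2 style), and that window_left keeps keys in [q_pos - window_left, q_pos]. The causal alignment follows the bottom-right convention of the custom-mask examples on this page; helper names are hypothetical.

```python
import math


def transform_logit(raw_score, head_dim, sm_scale=None, logits_soft_cap=None):
    """Scale a raw q.k score, then optionally apply tanh soft capping."""
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)  # documented default
    x = raw_score * sm_scale
    if logits_soft_cap is not None and logits_soft_cap > 0:
        # Assumption: capping acts on the already-scaled logit (Gemma-2 style).
        x = logits_soft_cap * math.tanh(x / logits_soft_cap)
    return x


def is_visible(qo_idx, kv_idx, qo_len, kv_len, causal=False, window_left=-1):
    """Whether query qo_idx may attend to key kv_idx within one request."""
    q_pos = kv_len - qo_len + qo_idx  # bottom-right alignment, as in the custom-mask examples
    if causal and kv_idx > q_pos:
        return False
    if window_left >= 0 and kv_idx < q_pos - window_left:
        return False  # assumed window: keys in [q_pos - window_left, q_pos]
    return True
```

Because tanh saturates at 1, a capped logit can never exceed logits_soft_cap in magnitude, while small logits pass through almost unchanged.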

Note

The plan() method should be called before any run() or run_return_lse() calls; auxiliary data structures are created during this call and cached for multiple kernel runs.

The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.

The plan() method cannot be used inside CUDA graph capture or in torch.compile.

reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer) → None

Reset the workspace buffer.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The new float workspace buffer; it should be on the same device as the input tensors.

  • int_workspace_buffer (torch.Tensor) – The new int workspace buffer; it should be on the same device as the input tensors.

run(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, return_lse: Literal[False] = False) → torch.Tensor
run(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, return_lse: Literal[True] = True) → Tuple[torch.Tensor, torch.Tensor]

Compute batch prefill/append attention between query and kv-cache stored as ragged tensor.

Parameters:
  • q (torch.Tensor) – The query tensor, shape: [qo_indptr[-1], num_qo_heads, head_dim]

  • k (torch.Tensor) – The key tensor, shape: [kv_indptr[-1], num_kv_heads, head_dim]

  • v (torch.Tensor) – The value tensor, shape: [kv_indptr[-1], num_kv_heads, head_dim]

  • return_lse (bool) – Whether to also return the logsumexp (LSE) of the attention logits.

Returns:

If return_lse is False, the attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim]. If return_lse is True, a tuple of two tensors:

  • The attention output, shape: [qo_indptr[-1], num_qo_heads, head_dim].

  • The logsumexp of the attention logits, shape: [qo_indptr[-1], num_qo_heads].

Return type:

Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
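The LSE returned with return_lse=True (by either wrapper) is what makes partial attention results composable: two outputs computed over disjoint sets of keys, for example a shared prefix and a per-request suffix, can be merged exactly into the output over all keys. The pure-Python sketch below shows the merge rule for one query row and one head. It assumes natural-log LSE values; if the LSE you obtain uses another base, adjust exp/log accordingly. FlashInfer's cascade module offers fused merge operators (such as merge_state) that are the practical choice on GPU tensors.

```python
import math


def merge_attention_states(o_a, lse_a, o_b, lse_b):
    """Merge two partial attention results for one query row and one head.

    o_a/lse_a and o_b/lse_b come from attention over two disjoint sets of keys.
    Assumes natural-log LSE values.
    """
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))  # stable log-add-exp
    w_a, w_b = math.exp(lse_a - lse), math.exp(lse_b - lse)  # each part's softmax mass
    return [w_a * x + w_b * y for x, y in zip(o_a, o_b)], lse


# Keys split into {key 0} and {key 1}; logits 1.0 and 0.0; values [1, 2] and [3, 4].
# With a single key, a partial output is that key's value and its LSE is that key's logit.
o, lse = merge_attention_states([1.0, 2.0], 1.0, [3.0, 4.0], 0.0)
```

The merged result equals full softmax attention over both keys: the weights w_a and w_b are each part's share of the total softmax denominator.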