flashinfer.pod

POD (Prefill-On-Decode) attention executes a single-request prefill kernel and a batch-decode kernel concurrently in one launch, which is useful for serving stacks that overlap a chunked prefill with ongoing decode requests.

class flashinfer.pod.PODWithPagedKVCacheWrapper(float_workspace_buffer: Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, paged_kv_indptr_buffer: Tensor | None = None, paged_kv_indices_buffer: Tensor | None = None, paged_kv_last_page_len_buffer: Tensor | None = None, jit_args: List[Any] | None = None)

Wrapper class for POD-Attention with paged kv-cache (first proposed in https://arxiv.org/abs/2410.18038) for a batch of requests.

Check our tutorial for page table layout.

Examples

>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 8
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> decode_wrapper = flashinfer.PODWithPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> kv_page_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= kv_last_page_len <= page_size
>>> kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> kv_cache_at_layer = [
...     torch.randn(
...         max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> # create auxiliary data structures for batch decode attention
>>> decode_wrapper.plan(
...     kv_page_indptr,
...     kv_page_indices,
...     kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     pos_encoding_mode="NONE",
...     data_type=torch.float16
... )
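>>> # prefill inputs for the single prefill request (NHD layout); lengths are illustrative
>>> qo_len_p, kv_len_p = 128, 128
>>> outputs = []
>>> for i in range(num_layers):
...     q_p = torch.randn(qo_len_p, num_qo_heads, head_dim).half().to("cuda:0")
...     k_p = torch.randn(kv_len_p, num_kv_heads, head_dim).half().to("cuda:0")
...     v_p = torch.randn(kv_len_p, num_kv_heads, head_dim).half().to("cuda:0")
...     q_d = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
...     kv_cache = kv_cache_at_layer[i]
...     # fused prefill + batch decode attention, reusing the plan() data for all layers
...     o_p, o_d = decode_wrapper.run(q_p, k_p, v_p, q_d, kv_cache, causal_p=True)
...     outputs.append(o_d)
...
>>> outputs[0].shape
torch.Size([7, 64, 128])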

Note

To accelerate computation, FlashInfer’s POD-Attention creates some auxiliary data structures that can be reused across multiple attention calls (e.g. different Transformer layers). This wrapper class manages the lifecycle of these data structures.

__init__(float_workspace_buffer: Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, paged_kv_indptr_buffer: Tensor | None = None, paged_kv_indices_buffer: Tensor | None = None, paged_kv_last_page_len_buffer: Tensor | None = None, jit_args: List[Any] | None = None) None

Constructor of PODWithPagedKVCacheWrapper.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The user-reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB. The workspace buffer must be on the same device as the input tensors.

  • kv_layout (str) – The layout of the input k/v tensors, could be either NHD or HND.

  • use_cuda_graph (bool) – Whether to enable CUDAGraph for batch decode attention. If enabled, the auxiliary data structures are stored in the provided buffers. The batch_size cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.

  • paged_kv_indptr_buffer (Optional[torch.Tensor]) – The user reserved buffer on GPU to store the indptr of the paged kv cache, the size of the buffer should be [batch_size + 1]. Only needed when use_cuda_graph is True.

  • paged_kv_indices_buffer (Optional[torch.Tensor]) – The user reserved buffer on GPU to store the page indices of the paged kv cache, should be large enough to store the maximum number of page indices (max_num_pages) during the lifecycle of this wrapper. Only needed when use_cuda_graph is True.

  • paged_kv_last_page_len_buffer (Optional[torch.Tensor]) – The user reserved buffer on GPU to store the number of entries in the last page, the size of the buffer should be [batch_size]. Only needed when use_cuda_graph is True.

  • jit_args (Optional[List[Any]]) – If provided, the wrapper will use the provided arguments to create the JIT module, otherwise, the wrapper will use default attention implementation.
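
The buffer sizes listed above can be worked out up front from the largest batch and page count the wrapper will ever see. The sketch below only computes shapes; cuda_graph_buffer_shapes is a hypothetical helper written for illustration and is not part of FlashInfer.

```python
def cuda_graph_buffer_shapes(batch_size, max_num_pages):
    """Hypothetical helper: shapes of the index buffers described above.

    Keys mirror the constructor keyword names; values are tensor shapes.
    """
    if batch_size < 1 or max_num_pages < 1:
        raise ValueError("batch_size and max_num_pages must be positive")
    return {
        "paged_kv_indptr_buffer": (batch_size + 1,),
        "paged_kv_indices_buffer": (max_num_pages,),
        "paged_kv_last_page_len_buffer": (batch_size,),
    }


shapes = cuda_graph_buffer_shapes(batch_size=7, max_num_pages=128)
```

Each shape could then be passed to torch.empty(shape, dtype=torch.int32, device="cuda:0"), matching the int32 index tensors used in the example, and handed to the constructor under the same keyword name.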

plan(indptr: Tensor, indices: Tensor, last_page_len: Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, pos_encoding_mode: str = 'NONE', window_left: int = -1, q_data_type: str | dtype | None = 'float16', kv_data_type: str | dtype | None = None, data_type: str | dtype | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, non_blocking: bool = True) None

Plan POD’s batch decode for given problem specification.

Parameters:
  • indptr (torch.Tensor) – The indptr of the paged kv cache, shape: [batch_size + 1]

  • indices (torch.Tensor) – The page indices of the paged kv cache, shape: [indptr[-1]]

  • last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv cache, shape: [batch_size]

  • num_qo_heads (int) – The number of query/output heads

  • num_kv_heads (int) – The number of key/value heads

  • head_dim (int) – The dimension of the heads

  • page_size (int) – The page size of the paged kv cache

  • pos_encoding_mode (str) – The position encoding applied inside attention kernels, one of NONE, ROPE_LLAMA (LLAMA-style rotary embedding), or ALIBI. Defaults to NONE.

  • window_left (int) – The left (inclusive) window size for the attention window. When set to -1, the window covers the full length of the sequence. Defaults to -1.

  • q_data_type (Optional[Union[str, torch.dtype]]) – The data type of the query tensor, defaults to torch.float16.

  • kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, will be set to q_data_type. Defaults to None.

  • data_type (Optional[Union[str, torch.dtype]]) – The data type of both the query and key/value tensors. Defaults to torch.float16. data_type is deprecated, please use q_data_type and kv_data_type instead.

  • sm_scale (Optional[float]) – Softmax scale. If None, defaults to 1 / sqrt(head_dim). Cached on the wrapper and reused at run() time.

  • rope_scale (Optional[float]) – Scale factor applied during RoPE interpolation. Only consulted when pos_encoding_mode != "NONE". Defaults to 1.0 when None.

  • rope_theta (Optional[float]) – Base value for the RoPE frequencies. Only consulted when pos_encoding_mode != "NONE". Defaults to 1e4 when None.

  • non_blocking (bool) – Whether to copy the input tensors to the device asynchronously, defaults to True.

Note

The plan() method should be called before any run() or run_return_lse() calls. Auxiliary data structures are created during this call and cached for subsequent run calls.

The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.

The plan() method cannot be used in CUDA Graph or in torch.compile.
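
The three page-table tensors passed to plan() jointly encode each decode request's KV length. As a rough sketch (plain Python lists instead of tensors, and a hypothetical helper name kv_lengths), the lengths can be recovered as below. It assumes every request owns at least one page, consistent with the 1 <= kv_last_page_len <= page_size comment in the example.

```python
def kv_lengths(indptr, last_page_len, page_size):
    """Per-request KV length implied by a paged layout (illustrative only)."""
    if len(indptr) != len(last_page_len) + 1:
        raise ValueError("indptr must have batch_size + 1 entries")
    lengths = []
    for i, last in enumerate(last_page_len):
        num_pages = indptr[i + 1] - indptr[i]
        if num_pages < 1 or not 1 <= last <= page_size:
            raise ValueError(f"request {i}: invalid page table")
        # All pages except the last are full; the last holds `last` entries.
        lengths.append((num_pages - 1) * page_size + last)
    return lengths


# Same values as the example at the top of this class.
indptr = [0, 17, 29, 44, 48, 66, 100, 128]
last_page_len = [1, 7, 14, 4, 3, 1, 16]
lens = kv_lengths(indptr, last_page_len, page_size=16)
```

For instance, the first request spans 17 pages with one entry in its last page, so it holds 16 * 16 + 1 = 257 tokens.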

reset_workspace_buffer(float_workspace_buffer: Tensor, int_workspace_buffer: Tensor) None

Reset the workspace buffer.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors.

  • int_workspace_buffer (torch.Tensor) – The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors.

run(q_p: Tensor, k_p: Tensor, v_p: Tensor, q_d: Tensor, paged_kv_cache_d: Tensor | Tuple[Tensor, Tensor], custom_mask_p: Tensor | None = None, packed_custom_mask_p: Tensor | None = None, causal_p: bool = False, kv_layout_p: str = 'NHD', pos_encoding_mode_p: str = 'NONE', sm_scale_p: float | None = None, window_left_p: int = -1, rope_scale_p: float | None = None, rope_theta_p: float | None = None, return_lse_p: bool = False, custom_mask_d: Tensor | None = None, packed_custom_mask_d: Tensor | None = None, causal_d: bool = False, kv_layout_d: str = 'NHD', pos_encoding_mode_d: str = 'NONE', sm_scale_d: float | None = None, window_left_d: int = -1, rope_scale_d: float | None = None, rope_theta_d: float | None = None, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, return_lse_d: bool = False, use_fp16_qk_reduction: bool = False, enable_pdl: bool | None = None, *args) Tensor | Tuple[Tensor, Tensor]

Compute POD (Prefill-On-Decode) fused attention for a batch of requests.

Single-shot fused attention that simultaneously runs single-request prefill (q_p against k_p/v_p) and batch-decode (q_d against paged_kv_cache_d) so prefill and decode tokens of a chunked-prefill / continuous-batching iteration share kernel launch and scheduling resources.

Parameters:
  • q_p (torch.Tensor) – Prefill query tensor, shape [qo_len, num_qo_heads, head_dim].

  • k_p (torch.Tensor) – Prefill key tensor. Layout matches kv_layout_p.

  • v_p (torch.Tensor) – Prefill value tensor. Layout matches kv_layout_p.

  • q_d (torch.Tensor) – Decode query tensor, shape [batch_size, num_qo_heads, head_dim].

  • paged_kv_cache_d (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – Paged KV cache for the decode requests. Layout matches the wrapper’s kv_layout set in __init__().

  • custom_mask_p (Optional[torch.Tensor]) – Optional dense custom mask for the prefill side. See flashinfer.single_prefill_with_kv_cache() for the layout.

  • packed_custom_mask_p (Optional[torch.Tensor]) – Optional bit-packed custom mask for the prefill side. See flashinfer.single_prefill_with_kv_cache() for the layout.

  • causal_p (bool) – Whether to apply a causal mask to the prefill side. Defaults to False.

  • kv_layout_p (str) – Layout of k_p and v_p, either "NHD" or "HND".

  • pos_encoding_mode_p (str) – Position-encoding mode for the prefill side. Defaults to "NONE".

  • sm_scale_p (Optional[float]) – Softmax scale for the prefill side. Defaults to 1 / sqrt(head_dim).

  • window_left_p (int) – Left window size for sliding-window prefill; -1 disables it.

  • rope_scale_p (Optional[float]) – RoPE scale for the prefill side. Defaults to 1.0.

  • rope_theta_p (Optional[float]) – RoPE theta for the prefill side. Defaults to 1e4.

  • return_lse_p (bool) – If True, allocate an LSE buffer for the prefill kernel to write into. Defaults to False. Note: the buffer is allocated and filled by the kernel but is not currently returned to the caller; run() always returns just (out_p, out_d). This is a known limitation of the current POD wrapper: the kernel API exists, but the Python wrapper does not yet plumb LSE through the return value.

  • custom_mask_d (Optional[torch.Tensor]) – Optional dense custom mask for the decode side.

  • packed_custom_mask_d (Optional[torch.Tensor]) – Optional bit-packed custom mask for the decode side.

  • causal_d (bool) – Whether to apply a causal mask to the decode side. Defaults to False.

  • kv_layout_d (str) – Currently ignored: the decode KV layout is always taken from the wrapper’s kv_layout (set in __init__()); this argument is accepted for signature symmetry with kv_layout_p but the value is not consulted by the kernel.

  • pos_encoding_mode_d (str) – Currently ignored: overridden by self._pos_encoding_mode from plan().

  • sm_scale_d (Optional[float]) – Currently ignored: overridden by self._sm_scale from plan() (which itself defaults to 1 / sqrt(head_dim)).

  • window_left_d (int) – Currently ignored: overridden by self._window_left from plan().

  • rope_scale_d (Optional[float]) – Currently ignored: overridden by self._rope_scale from plan().

  • rope_theta_d (Optional[float]) – Currently ignored: overridden by self._rope_theta from plan().

  • q_scale (Optional[float]) – FP8 calibration scale for the decode-side query. Folded into the decode sm_scale.

  • k_scale (Optional[float]) – FP8 calibration scale for the decode-side key. Folded into the decode sm_scale.

  • v_scale (Optional[float]) – FP8 calibration scale for the decode-side value. Applied to the kernel output.

  • return_lse_d (bool) – If True, allocate an LSE buffer for the decode kernel to write into. Defaults to False. See return_lse_p – same caveat: allocated but not returned.

  • use_fp16_qk_reduction (bool) – Whether to accumulate QK in FP16 (lower precision, higher throughput). Defaults to False.

  • enable_pdl (Optional[bool]) – Programmatic Dependent Launch toggle. When None (default), the wrapper auto-detects support from the query device.

  • *args – Reserved for forward-compat with future kernel parameters.

Returns:

(out_p, out_d): the prefill output (shape [qo_len, num_qo_heads, head_dim]) and the decode output (shape [batch_size, num_qo_heads, head_dim]). LSE tensors are not part of the return value even when return_lse_p / return_lse_d is True (see those parameter notes).

Return type:

Tuple[torch.Tensor, torch.Tensor]
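
The q_scale, k_scale and v_scale parameters of run() are described as being folded into the decode softmax scale or the kernel output. A minimal sketch of that arithmetic follows, assuming that "folded" means plain multiplication; effective_decode_scales is a hypothetical helper for illustration, not a FlashInfer function.

```python
import math


def effective_decode_scales(head_dim, sm_scale=None, q_scale=None,
                            k_scale=None, v_scale=None):
    """Return (softmax_scale, output_scale) after folding in optional scales.

    Assumption: folding is multiplication, as the parameter notes suggest.
    """
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)  # documented default
    if q_scale is not None:
        sm_scale *= q_scale
    if k_scale is not None:
        sm_scale *= k_scale
    out_scale = 1.0 if v_scale is None else v_scale
    return sm_scale, out_scale


softmax_scale, output_scale = effective_decode_scales(
    head_dim=128, q_scale=0.5, k_scale=2.0, v_scale=0.25
)
```

Here q_scale and k_scale cancel, so the softmax scale stays at 1 / sqrt(128), while the output would be multiplied by 0.25.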

class flashinfer.pod.BatchPODWithPagedKVCacheWrapper(float_workspace_buffer: Tensor, kv_layout: str = 'NHD')

Wrapper class for POD-Attention with paged kv-cache (first proposed in https://arxiv.org/abs/2410.18038) for a batch of requests.

Check our tutorial for page table layout.

Examples

>>> import torch
>>> import flashinfer
>>> num_layers = 8
>>> num_qo_heads = 64
>>> num_kv_heads = 8
>>> head_dim = 128
>>> max_num_pages = 128
>>> device = 0
>>> page_block_size = 1
>>> causal = True
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> wrapper = flashinfer.BatchPODWithPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> # Prefill and decode parameters
>>> p_qo_lens = [2048] * 2
>>> d_qo_lens = [1] * 128
>>> p_kv_lens = [2048] * 2
>>> d_kv_lens = [2048] * 128
>>> # Prefill plan inputs
>>> p_seq_lens_blocks = torch.ceil(
...     torch.tensor(p_kv_lens, dtype=torch.int32) / page_block_size
... ).int()
>>> p_q_indptr = torch.cat(
...     [torch.tensor([0]), torch.cumsum(torch.tensor(p_qo_lens), 0)], dim=0
... ).int()
>>> p_kv_indptr = torch.cat(
...     [torch.tensor([0]), torch.cumsum(p_seq_lens_blocks, 0)], dim=0
... ).int()
>>> kv_indices_p = torch.arange(0, p_kv_indptr[-1], device=device, dtype=torch.int32)
>>> last_page_len_p = (torch.tensor(p_kv_lens, dtype=torch.int32) - 1) % page_block_size + 1
>>> # Decode plan inputs
>>> d_seq_lens_blocks = torch.ceil(
...     torch.tensor(d_kv_lens, dtype=torch.int32) / page_block_size
... ).int()
>>> d_q_indptr = torch.cat(
...     [torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0
... ).int()
>>> d_kv_indptr = torch.cat(
...     [torch.tensor([0]), torch.cumsum(d_seq_lens_blocks, 0)], dim=0
... ).int()
>>> kv_indices_d = torch.arange(0, d_kv_indptr[-1], device=device, dtype=torch.int32)
>>> last_page_len_d = (torch.tensor(d_kv_lens, dtype=torch.int32) - 1) % page_block_size + 1
>>> # create auxiliary data structures for batched prefill and decode attention
>>> wrapper.plan(
...     # Prefill params
...     p_q_indptr.to(device),
...     p_kv_indptr.to(device),
...     kv_indices_p.to(device),
...     last_page_len_p,
...     # Decode params
...     d_q_indptr.to(device),
...     d_kv_indptr.to(device),
...     kv_indices_d.to(device),
...     last_page_len_d,
...     # Common params
...     num_qo_heads=num_qo_heads,
...     num_kv_heads=num_kv_heads,
...     head_dim=head_dim,
...     page_size=page_block_size,
...     q_data_type=torch.bfloat16,
...     kv_data_type=torch.bfloat16,
... )
>>> # Prefill input tensors
>>> q_p = torch.rand(p_q_indptr[-1].item(), num_qo_heads, head_dim).to(
...     device, dtype=torch.bfloat16
... )
>>> kv_p = torch.randn(p_kv_indptr[-1], 2, page_block_size, num_kv_heads, head_dim).to(
...     device, dtype=torch.bfloat16
... ).unbind(1)
>>> # Decode input tensors
>>> q_d = torch.rand(d_q_indptr[-1].item(), num_qo_heads, head_dim).to(
...     device, dtype=torch.bfloat16
... )
>>> kv_d = torch.randn(d_kv_indptr[-1], 2, page_block_size, num_kv_heads, head_dim).to(
...     device, dtype=torch.bfloat16
... ).unbind(1)
>>> for i in range(num_layers):
...     o_p_batch, o_d_batch = wrapper.run(
...         q_p,
...         kv_p,
...         q_d,
...         kv_d,
...         causal_p=causal,
...     )
>>> print(o_p_batch.shape, o_d_batch.shape)
torch.Size([4096, 64, 128]) torch.Size([128, 64, 128])
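
The example above uses page_block_size = 1, so every page holds exactly one token. For larger pages, the indptr and last-page lengths follow from the token counts by ceiling division. A plain-Python sketch of that bookkeeping is shown below; paged_plan_inputs is a hypothetical helper, and it assumes requests are laid out back to back with at least one KV entry each.

```python
def paged_plan_inputs(kv_lens, page_size):
    """Build (kv_indptr, last_page_len) lists from per-request KV lengths."""
    kv_indptr = [0]
    last_page_len = []
    for kv_len in kv_lens:
        if kv_len < 1:
            raise ValueError("each request needs at least one KV entry")
        num_pages = -(-kv_len // page_size)  # ceiling division
        kv_indptr.append(kv_indptr[-1] + num_pages)
        # Entries in the final page, always in the range 1..page_size.
        last_page_len.append((kv_len - 1) % page_size + 1)
    return kv_indptr, last_page_len


kv_indptr, last_page_len = paged_plan_inputs([2048, 100, 17], page_size=16)
```

The resulting lists could be converted with torch.tensor(..., dtype=torch.int32) before being passed to plan().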

Note

To accelerate computation, FlashInfer’s POD-Attention creates some auxiliary data structures that can be reused across multiple attention calls (e.g. different Transformer layers). This wrapper class manages the lifecycle of these data structures.

__init__(float_workspace_buffer: Tensor, kv_layout: str = 'NHD') None

Constructor of BatchPODWithPagedKVCacheWrapper.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The user-reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB. The workspace buffer must be on the same device as the input tensors.

  • kv_layout (str) – The layout of the input k/v tensors, could be either NHD or HND.

plan(qo_indptr_p: Tensor, kv_indptr_p: Tensor, kv_indices_p: Tensor, last_page_len_p: Tensor, qo_indptr_d: Tensor, kv_indptr_d: Tensor, kv_indices_d: Tensor, last_page_len_d: Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, pos_encoding_mode: str = 'NONE', window_left: int = -1, q_data_type: str | dtype | None = 'float16', kv_data_type: str | dtype | None = None, data_type: str | dtype | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, non_blocking: bool = True) None

Plan POD’s batch prefill and decode for given problem specification.

Parameters:
  • qo_indptr_p (torch.Tensor) – The prefill indptr of the query/output tensor, shape: [batch_size + 1].

  • kv_indptr_p (torch.Tensor) – The prefill indptr of the paged kv-cache, shape: [batch_size + 1].

  • kv_indices_p (torch.Tensor) – The prefill page indices of the paged kv-cache, shape: [kv_indptr_p[-1]].

  • last_page_len_p (torch.Tensor) – The number of entries in the last page of each prefill request in the paged kv-cache, shape: [batch_size].

  • qo_indptr_d (torch.Tensor) – The decode indptr of the query/output tensor, shape: [batch_size + 1].

  • kv_indptr_d (torch.Tensor) – The decode indptr of the paged kv-cache, shape: [batch_size + 1].

  • kv_indices_d (torch.Tensor) – The decode page indices of the paged kv-cache, shape: [kv_indptr_d[-1]].

  • last_page_len_d (torch.Tensor) – The number of entries in the last page of each decode request in the paged kv-cache, shape: [batch_size].

  • num_qo_heads (int) – The number of query/output heads

  • num_kv_heads (int) – The number of key/value heads

  • head_dim (int) – The dimension of the heads

  • page_size (int) – The page size of the paged kv cache

  • pos_encoding_mode (str) – The position encoding applied inside attention kernels, one of NONE, ROPE_LLAMA (LLAMA-style rotary embedding), or ALIBI. Defaults to NONE.

  • window_left (int) – The left (inclusive) window size for the attention window. When set to -1, the window covers the full length of the sequence. Defaults to -1.

  • q_data_type (Optional[Union[str, torch.dtype]]) – The data type of the query tensor, defaults to torch.float16.

  • kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, will be set to q_data_type. Defaults to None.

  • data_type (Optional[Union[str, torch.dtype]]) – The data type of both the query and key/value tensors. Defaults to torch.float16. data_type is deprecated, please use q_data_type and kv_data_type instead.

  • sm_scale (Optional[float]) – The scale used in softmax. If not provided, it is set to 1.0 / sqrt(head_dim).

  • rope_scale (Optional[float]) – The scale used in RoPE interpolation, if not provided, will be set to 1.0.

  • rope_theta (Optional[float]) – The theta used in RoPE, if not provided, will be set to 1e4.

  • non_blocking (bool) – Whether to copy the input tensors to the device asynchronously, defaults to True.
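
The window_left and causal settings together decide which keys a query may attend to. The sketch below is one reading of "left (inclusive) window", assuming the key exactly window_left positions behind the query is still visible; in_window is a hypothetical helper working on absolute token positions and does not reproduce the kernel's indexing.

```python
def in_window(q_pos, k_pos, window_left=-1, causal=True):
    """True if the key at k_pos is visible to the query at q_pos."""
    if causal and k_pos > q_pos:
        return False  # causal mask hides future keys
    if window_left >= 0 and k_pos < q_pos - window_left:
        return False  # key is further back than the window allows
    return True


# Keys 0..7, query at position 5.
visible = [k for k in range(8) if in_window(q_pos=5, k_pos=k, window_left=2)]
full = [k for k in range(8) if in_window(q_pos=5, k_pos=k)]
```

Under this assumption, window_left=2 leaves three visible keys (positions 3, 4 and 5), while the default of -1 keeps every key up to the query position.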

Note

The plan() method should be called before any run() or run_return_lse() calls. Auxiliary data structures are created during this call and cached for subsequent run calls.

The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.

The plan() method cannot be used in CUDA Graph or in torch.compile.
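
The divisibility requirement on num_qo_heads exists because, under grouped query attention, several query heads read the same KV head. The sketch below assumes the conventional layout in which consecutive query heads share one KV head; the text above does not spell out the mapping, and kv_head_for_qo_head is a hypothetical helper.

```python
def kv_head_for_qo_head(qo_head, num_qo_heads, num_kv_heads):
    """Index of the KV head read by a query head (assumed contiguous groups)."""
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    group_size = num_qo_heads // num_kv_heads
    return qo_head // group_size


# Head counts from the example: 64 query heads over 8 KV heads.
mapping = [kv_head_for_qo_head(h, 64, 8) for h in range(64)]
```

With 64 query heads and 8 KV heads the group size is 8, so query heads 0 to 7 would read KV head 0, heads 8 to 15 would read KV head 1, and so on.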

run(q_p: Tensor, paged_kv_cache_p: Tensor | Tuple[Tensor, Tensor], q_d: Tensor, paged_kv_cache_d: Tensor | Tuple[Tensor, Tensor], custom_mask_p: Tensor | None = None, packed_custom_mask_p: Tensor | None = None, causal_p: bool = False, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, return_lse: bool = False, use_fp16_qk_reduction: bool = False, enable_pdl: bool | None = None) Tuple[Tensor, Tensor] | Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]

Compute batched POD (Prefill-On-Decode) fused attention.

Fused single-shot attention that runs batched paged prefill (q_p against paged_kv_cache_p) and batched paged decode (q_d against paged_kv_cache_d) in the same kernel launch, sharing scheduling and execution resources. All prefill / decode shape and policy parameters (pos-encoding, sliding window, sm_scale, RoPE, etc.) are taken from the cached values supplied to plan().

Parameters:
  • q_p (torch.Tensor) – Prefill query tensor, shape [qo_indptr_p[-1], num_qo_heads, head_dim].

  • paged_kv_cache_p (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – Paged KV cache for the prefill requests. Layout matches kv_layout set in __init__().

  • q_d (torch.Tensor) – Decode query tensor, shape [batch_size_d, num_qo_heads, head_dim].

  • paged_kv_cache_d (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – Paged KV cache for the decode requests. Layout matches kv_layout set in __init__().

  • custom_mask_p (Optional[torch.Tensor]) – Optional dense custom mask for the prefill side (auto-packed when packed_custom_mask_p is None).

  • packed_custom_mask_p (Optional[torch.Tensor]) – Optional bit-packed custom mask for the prefill side.

  • causal_p (bool) – Whether to apply a causal mask to the prefill side. Defaults to False.

  • q_scale (Optional[float]) – FP8 calibration scale for the decode-side query. Folded into the decode sm_scale.

  • k_scale (Optional[float]) – FP8 calibration scale for the decode-side key. Folded into the decode sm_scale.

  • v_scale (Optional[float]) – FP8 calibration scale for the decode-side value. Applied to the kernel output.

  • return_lse (bool) – Whether to return the LSE tensors for both prefill and decode. Defaults to False.

  • use_fp16_qk_reduction (bool) – Whether to accumulate QK in FP16 (lower precision, higher throughput). Defaults to False.

  • enable_pdl (Optional[bool]) – Programmatic Dependent Launch toggle. When None (default), the wrapper auto-detects support from the query device.

Returns:

By default (out_p, out_d): the prefill output and the decode output. When return_lse is True the return becomes ((out_p, lse_p), (out_d, lse_d)).

Return type:

Union[Tuple[torch.Tensor, torch.Tensor], Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]]
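
Because the return structure changes with return_lse, calling code may want to normalize it. The sketch below shows one way to do that; unpack_pod_result is a hypothetical helper, and strings stand in for tensors so it runs without a GPU.

```python
def unpack_pod_result(result, return_lse):
    """Normalize run() output to (out_p, lse_p, out_d, lse_d).

    The LSE slots are None when return_lse is False.
    """
    if return_lse:
        (out_p, lse_p), (out_d, lse_d) = result
        return out_p, lse_p, out_d, lse_d
    out_p, out_d = result
    return out_p, None, out_d, None


# Stand-in values with the two documented return structures.
plain = unpack_pod_result(("out_p", "out_d"), return_lse=False)
with_lse = unpack_pod_result(
    (("out_p", "lse_p"), ("out_d", "lse_d")), return_lse=True
)
```

Passing the same return_lse flag to both run() and the helper keeps downstream code independent of which form was requested.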