flashinfer.pod
POD (Prefill-On-Decode) attention executes a single-request prefill kernel and a batch-decode kernel concurrently in one launch, which is useful for serving stacks that overlap a chunked prefill with ongoing decode requests.
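In a chunked-prefill serving loop, each iteration pairs one chunk of a long prompt with the decode requests that are already running. The helper below is a hypothetical scheduler sketch, not part of FlashInfer. It assumes the KV of earlier chunks is retained and re-read. For each iteration it returns the prefill query length (qo_len) and the KV length that chunk attends to, which is what the prefill inputs of a single POD call would cover.

```python
from typing import List, Tuple


def prefill_chunks(prompt_len: int, chunk_size: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: (qo_len, kv_len) of the prefill half of each POD step.

    Step s processes prompt tokens [s * chunk_size, min((s + 1) * chunk_size, prompt_len))
    and, assuming earlier chunks' KV is kept, attends to every KV entry up to
    the end of that chunk.
    """
    if prompt_len <= 0 or chunk_size <= 0:
        raise ValueError("prompt_len and chunk_size must be positive")
    steps = []
    for start in range(0, prompt_len, chunk_size):
        end = min(start + chunk_size, prompt_len)
        steps.append((end - start, end))  # (query tokens this step, KV visible)
    return steps
```

For a 5000-token prompt and 2048-token chunks, this yields three POD iterations. Their prefill halves have qo_len 2048, 2048 and 904, against kv_len 2048, 4096 and 5000.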
- class flashinfer.pod.PODWithPagedKVCacheWrapper(float_workspace_buffer: Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, paged_kv_indptr_buffer: Tensor | None = None, paged_kv_indices_buffer: Tensor | None = None, paged_kv_last_page_len_buffer: Tensor | None = None, jit_args: List[Any] | None = None)
Wrapper class for POD-Attention with paged kv-cache (first proposed in https://arxiv.org/abs/2410.18038), which fuses a single prefill request with a batch of decode requests.
Check our tutorial for the page table layout.
Examples
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 8
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> pod_wrapper = flashinfer.PODWithPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> kv_page_indptr = torch.tensor(
...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= kv_last_page_len <= page_size
>>> kv_last_page_len = torch.tensor(
...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> kv_cache_at_layer = [
...     torch.randn(
...         max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
...     ) for _ in range(num_layers)
... ]
>>> # create auxiliary data structures for the batch-decode half of POD attention
>>> pod_wrapper.plan(
...     kv_page_indptr,
...     kv_page_indices,
...     kv_last_page_len,
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
...     page_size,
...     pos_encoding_mode="NONE",
...     q_data_type=torch.float16,
...     kv_data_type=torch.float16,
... )
>>> # length of the single prefill chunk that runs alongside the decode batch
>>> qo_len_p = kv_len_p = 512
>>> outputs = []
>>> for i in range(num_layers):
...     q_p = torch.randn(qo_len_p, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
...     k_p = torch.randn(kv_len_p, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
...     v_p = torch.randn(kv_len_p, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
...     q_d = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
...     kv_cache = kv_cache_at_layer[i]
...     # fused prefill + batch decode attention; plan() data is reused for all layers
...     o_p, o_d = pod_wrapper.run(q_p, k_p, v_p, q_d, kv_cache, causal_p=True)
...     outputs.append(o_d)
...
>>> outputs[0].shape
torch.Size([7, 64, 128])
Note
To accelerate computation, FlashInfer’s POD-Attention creates some auxiliary data structures. These data structures can be reused across multiple batch decode attention calls (e.g. different Transformer layers), and this wrapper class manages their lifecycle.
- __init__(float_workspace_buffer: Tensor, kv_layout: str = 'NHD', use_cuda_graph: bool = False, paged_kv_indptr_buffer: Tensor | None = None, paged_kv_indices_buffer: Tensor | None = None, paged_kv_last_page_len_buffer: Tensor | None = None, jit_args: List[Any] | None = None) None
Constructor of PODWithPagedKVCacheWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user-reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB; the workspace buffer should be on the same device as the input tensors.
kv_layout (str) – The layout of the input k/v tensors, either NHD or HND.
use_cuda_graph (bool) – Whether to enable CUDAGraph for batch decode attention. If enabled, the auxiliary data structures are stored in the provided buffers, and batch_size cannot change during the lifecycle of this wrapper.
paged_kv_indptr_buffer (Optional[torch.Tensor]) – The user-reserved GPU buffer that stores the indptr of the paged kv cache; its size should be [batch_size + 1]. Only needed when use_cuda_graph is True.
paged_kv_indices_buffer (Optional[torch.Tensor]) – The user-reserved GPU buffer that stores the page indices of the paged kv cache; it should be large enough to hold the maximum number of page indices (max_num_pages) used during the lifecycle of this wrapper. Only needed when use_cuda_graph is True.
paged_kv_last_page_len_buffer (Optional[torch.Tensor]) – The user-reserved GPU buffer that stores the number of entries in the last page of each request; its size should be [batch_size]. Only needed when use_cuda_graph is True.
jit_args (Optional[List[Any]]) – If provided, the wrapper uses these arguments to create the JIT module; otherwise it uses the default attention implementation.
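When use_cuda_graph=True, the buffer sizes above are fixed for the wrapper's lifetime. The sketch below uses a hypothetical helper that is not part of FlashInfer. It records those sizes exactly as documented and checks whether a page table (indptr, last_page_len) destined for plan() can be copied into the preallocated buffers.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class GraphBufferSpec:
    """Hypothetical record of the buffer sizes fixed at construction time."""
    batch_size: int      # cannot change while CUDA graphs are enabled
    max_num_pages: int   # capacity of paged_kv_indices_buffer

    def shapes(self) -> dict:
        # Buffer sizes as documented for the three paged-KV buffers.
        return {
            "paged_kv_indptr_buffer": (self.batch_size + 1,),
            "paged_kv_indices_buffer": (self.max_num_pages,),
            "paged_kv_last_page_len_buffer": (self.batch_size,),
        }

    def fits(self, indptr: Sequence[int], last_page_len: Sequence[int]) -> bool:
        """True if a page table can be copied into these buffers."""
        return (
            len(indptr) == self.batch_size + 1
            and len(last_page_len) == self.batch_size
            and indptr[-1] <= self.max_num_pages  # total pages referenced
        )
```

With torch, each shape would typically be allocated once, e.g. torch.empty(shape, dtype=torch.int32, device="cuda"), and passed to the constructor.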
- plan(indptr: Tensor, indices: Tensor, last_page_len: Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, pos_encoding_mode: str = 'NONE', window_left: int = -1, q_data_type: str | dtype | None = 'float16', kv_data_type: str | dtype | None = None, data_type: str | dtype | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, non_blocking: bool = True) None
Plan POD’s batch decode for the given problem specification.
- Parameters:
indptr (torch.Tensor) – The indptr of the paged kv cache, shape: [batch_size + 1].
indices (torch.Tensor) – The page indices of the paged kv cache, shape: [indptr[-1]].
last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv cache, shape: [batch_size].
num_qo_heads (int) – The number of query/output heads.
num_kv_heads (int) – The number of key/value heads.
head_dim (int) – The dimension of the heads.
page_size (int) – The page size of the paged kv cache.
pos_encoding_mode (str) – The position encoding applied inside attention kernels, one of NONE, ROPE_LLAMA (LLaMA-style rotary embedding) or ALIBI. Defaults to NONE.
window_left (int) – The left (inclusive) window size for the attention window. When set to -1, the window covers the full length of the sequence. Defaults to -1.
q_data_type (Optional[Union[str, torch.dtype]]) – The data type of the query tensor. Defaults to torch.float16.
kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, it is set to q_data_type. Defaults to None.
data_type (Optional[Union[str, torch.dtype]]) – The data type of both the query and key/value tensors. Defaults to torch.float16. data_type is deprecated; use q_data_type and kv_data_type instead.
sm_scale (Optional[float]) – Softmax scale. If None, defaults to 1 / sqrt(head_dim). Cached on the wrapper and reused at run() time.
rope_scale (Optional[float]) – Scale factor applied during RoPE interpolation. Only consulted when pos_encoding_mode != "NONE". Defaults to 1.0 when None.
rope_theta (Optional[float]) – Base value for the RoPE frequencies. Only consulted when pos_encoding_mode != "NONE". Defaults to 1e4 when None.
non_blocking (bool) – Whether to copy the input tensors to the device asynchronously. Defaults to True.
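The indptr / indices / last_page_len triple can be derived from per-request KV lengths. The sketch below is a hypothetical pure-Python helper, not a FlashInfer function, and it assumes pages are allocated contiguously in request order. With page_size = 16 and KV lengths of 257, 183, 238, 52, 275, 529 and 448 tokens, it reproduces the page table used in the class example.

```python
import math
from typing import List, Tuple


def build_page_table(kv_lens: List[int], page_size: int) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: allocate pages contiguously in request order.

    Request i owns pages indices[indptr[i]:indptr[i + 1]]; its last page
    holds last_page_len[i] valid entries, with 1 <= last_page_len[i] <= page_size.
    """
    indptr = [0]
    for kv_len in kv_lens:
        if kv_len <= 0:
            raise ValueError("every request needs at least one KV entry")
        num_pages = (kv_len + page_size - 1) // page_size  # ceil division
        indptr.append(indptr[-1] + num_pages)
    indices = list(range(indptr[-1]))  # pages handed out in order
    last_page_len = [(kv_len - 1) % page_size + 1 for kv_len in kv_lens]
    return indptr, indices, last_page_len


def default_sm_scale(head_dim: int) -> float:
    # Softmax scale used when sm_scale is None.
    return 1.0 / math.sqrt(head_dim)
```

Before calling plan(), the lists would be converted with torch.tensor(..., dtype=torch.int32).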
Note
The plan() method should be called before any run() or run_return_lse() calls; auxiliary data structures are created during this call and cached for multiple run calls.
The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function uses grouped query attention.
The plan() method cannot be used in CUDA Graph or in torch.compile.
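The grouped-query rule above can be made concrete. The sketch below assumes the usual contiguous grouping, where each run of num_qo_heads // num_kv_heads consecutive query heads shares one KV head. kv_head_for is a hypothetical helper, not a FlashInfer function.

```python
def kv_head_for(qo_head: int, num_qo_heads: int, num_kv_heads: int) -> int:
    """Hypothetical helper: KV head read by query head qo_head under GQA.

    Assumes contiguous grouping: query heads [g * group_size, (g + 1) * group_size)
    all read KV head g.
    """
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    if not 0 <= qo_head < num_qo_heads:
        raise ValueError("qo_head out of range")
    group_size = num_qo_heads // num_kv_heads
    return qo_head // group_size
```

With the class example's 64 query heads and 8 KV heads, query heads 0 to 7 read KV head 0 and query head 63 reads KV head 7. When the two head counts are equal, the group size is 1 and the mapping is the identity (plain multi-head attention).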
- reset_workspace_buffer(float_workspace_buffer: Tensor, int_workspace_buffer: Tensor) None
Reset the workspace buffer.
- Parameters:
float_workspace_buffer (torch.Tensor) – The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors.
int_workspace_buffer (torch.Tensor) – The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors.
- run(q_p: Tensor, k_p: Tensor, v_p: Tensor, q_d: Tensor, paged_kv_cache_d: Tensor | Tuple[Tensor, Tensor], custom_mask_p: Tensor | None = None, packed_custom_mask_p: Tensor | None = None, causal_p: bool = False, kv_layout_p: str = 'NHD', pos_encoding_mode_p: str = 'NONE', sm_scale_p: float | None = None, window_left_p: int = -1, rope_scale_p: float | None = None, rope_theta_p: float | None = None, return_lse_p: bool = False, custom_mask_d: Tensor | None = None, packed_custom_mask_d: Tensor | None = None, causal_d: bool = False, kv_layout_d: str = 'NHD', pos_encoding_mode_d: str = 'NONE', sm_scale_d: float | None = None, window_left_d: int = -1, rope_scale_d: float | None = None, rope_theta_d: float | None = None, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, return_lse_d: bool = False, use_fp16_qk_reduction: bool = False, enable_pdl: bool | None = None, *args) Tensor | Tuple[Tensor, Tensor]
Compute POD (Prefill-On-Decode) fused attention for a batch of requests.
Single-shot fused attention that simultaneously runs single-request prefill (q_p against k_p/v_p) and batch decode (q_d against paged_kv_cache_d), so that the prefill and decode tokens of a chunked-prefill / continuous-batching iteration share one kernel launch and its scheduling resources.
- Parameters:
q_p (torch.Tensor) – Prefill query tensor, shape [qo_len, num_qo_heads, head_dim].
k_p (torch.Tensor) – Prefill key tensor. Layout matches kv_layout_p.
v_p (torch.Tensor) – Prefill value tensor. Layout matches kv_layout_p.
q_d (torch.Tensor) – Decode query tensor, shape [batch_size, num_qo_heads, head_dim].
paged_kv_cache_d (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – Paged KV cache for the decode requests. Layout matches the wrapper’s kv_layout set in __init__().
custom_mask_p (Optional[torch.Tensor]) – Optional dense custom mask for the prefill side. See flashinfer.single_prefill_with_kv_cache() for the layout.
packed_custom_mask_p (Optional[torch.Tensor]) – Optional bit-packed custom mask for the prefill side. See flashinfer.single_prefill_with_kv_cache() for the layout.
causal_p (bool) – Whether to apply a causal mask to the prefill side. Defaults to False.
kv_layout_p (str) – Layout of k_p and v_p, either "NHD" or "HND".
pos_encoding_mode_p (str) – Position-encoding mode for the prefill side. Defaults to "NONE".
sm_scale_p (Optional[float]) – Softmax scale for the prefill side. Defaults to 1 / sqrt(head_dim) when None.
window_left_p (int) – Left window size for sliding-window prefill; -1 disables it.
rope_scale_p (Optional[float]) – RoPE interpolation scale for the prefill side. Defaults to 1.0 when None.
rope_theta_p (Optional[float]) – RoPE theta for the prefill side. Defaults to 1e4 when None.
return_lse_p (bool) – If True, allocate an LSE buffer for the prefill kernel to write into. Defaults to False. Note: the buffer is allocated and filled by the kernel but is not currently returned to the caller; run always returns just (out_p, out_d). This is a known limitation of the current POD wrapper (the kernel API exists, but the Python wrapper does not yet plumb LSE through the return value).
custom_mask_d (Optional[torch.Tensor]) – Optional dense custom mask for the decode side.
packed_custom_mask_d (Optional[torch.Tensor]) – Optional bit-packed custom mask for the decode side.
causal_d (bool) – Whether to apply a causal mask to the decode side. Defaults to False.
kv_layout_d (str) – Currently ignored: the decode KV layout is always taken from the wrapper’s kv_layout (set in __init__()). This argument is accepted for signature symmetry with kv_layout_p, but its value is not consulted by the kernel.
pos_encoding_mode_d (str) – Currently ignored: overridden by self._pos_encoding_mode from plan().
sm_scale_d (Optional[float]) – Currently ignored: overridden by self._sm_scale from plan() (which itself defaults to 1 / sqrt(head_dim)).
window_left_d (int) – Currently ignored: overridden by self._window_left from plan().
rope_scale_d (Optional[float]) – Currently ignored: overridden by self._rope_scale from plan().
rope_theta_d (Optional[float]) – Currently ignored: overridden by self._rope_theta from plan().
q_scale (Optional[float]) – FP8 calibration scale for the decode queries, folded into the decode sm_scale.
k_scale (Optional[float]) – FP8 calibration scale for the decode keys, folded into the decode sm_scale.
v_scale (Optional[float]) – FP8 calibration scale for the decode values, applied to the decode kernel output.
return_lse_d (bool) – If True, allocate an LSE buffer for the decode kernel to write into. Defaults to False. The same caveat as for return_lse_p applies: the buffer is allocated but not returned.
use_fp16_qk_reduction (bool) – Whether to accumulate QK in FP16 (lower precision, higher throughput). Defaults to False.
enable_pdl (Optional[bool]) – Programmatic Dependent Launch toggle. When None (default), the wrapper auto-detects support from the query device.
*args – Reserved for forward compatibility with future kernel parameters.
- Returns:
(out_p, out_d): the prefill output (shape [qo_len, num_qo_heads, head_dim]) and the decode output (shape [batch_size, num_qo_heads, head_dim]). LSE tensors are not part of the return value even when return_lse_p / return_lse_d is True (see those parameter notes).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
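For the prefill side of run(), causal_p and window_left_p, or an explicit custom_mask_p, decide which keys each prefill query may see. The pure-Python reference below is a sketch, not FlashInfer code, and it rests on two assumptions that should be checked against the kernel. First, when qo_len < kv_len the causal diagonal is aligned to the bottom-right corner, so a chunk's queries see all earlier chunks. Second, a bit-packed mask uses row-major flattening with little-endian bit order within each byte. The sketch builds a dense mask of the kind custom_mask_p carries and the packed bytes that packed_custom_mask_p would hold.

```python
from typing import List


def prefill_mask(qo_len: int, kv_len: int, causal: bool, window_left: int = -1) -> List[List[bool]]:
    """Dense reference mask: mask[i][j] is True when query i may attend to key j.

    Assumption: query i sits at absolute position i + (kv_len - qo_len)
    (bottom-right causal alignment). window_left = -1 disables the window;
    otherwise keys older than window_left positions before the query are masked.
    """
    offset = kv_len - qo_len
    mask = []
    for i in range(qo_len):
        pos = i + offset  # absolute position of query i
        row = []
        for j in range(kv_len):
            allowed = not (causal and j > pos)
            if window_left >= 0 and j < pos - window_left:
                allowed = False
            row.append(allowed)
        mask.append(row)
    return mask


def pack_mask_little(mask: List[List[bool]]) -> bytes:
    """Flatten a dense mask row-major and pack 8 entries per byte.

    Assumption: bit k of byte b holds flat element 8 * b + k (little-endian
    bit order, like numpy.packbits(..., bitorder="little")).
    """
    flat = [bool(x) for row in mask for x in row]
    packed = bytearray((len(flat) + 7) // 8)
    for idx, bit in enumerate(flat):
        if bit:
            packed[idx // 8] |= 1 << (idx % 8)
    return bytes(packed)
```

A packed mask needs ceil(qo_len * kv_len / 8) bytes rather than one byte per entry, which matters for long prefill chunks.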
- class flashinfer.pod.BatchPODWithPagedKVCacheWrapper(float_workspace_buffer: Tensor, kv_layout: str = 'NHD')
Wrapper class for POD-Attention with paged kv-cache (first proposed in https://arxiv.org/abs/2410.18038) for a batch of prefill requests fused with a batch of decode requests.
Check our tutorial for the page table layout.
Examples
>>> import torch
>>> import flashinfer
>>> num_layers = 8
>>> num_qo_heads = 64
>>> num_kv_heads = 8
>>> head_dim = 128
>>> max_num_pages = 128
>>> device = 0
>>> page_block_size = 1
>>> causal = True
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> wrapper = flashinfer.BatchPODWithPagedKVCacheWrapper(
...     workspace_buffer, "NHD"
... )
>>> # Prefill and decode parameters
>>> p_qo_lens = [2048] * 2
>>> d_qo_lens = [1] * 128
>>> p_kv_lens = [2048] * 2
>>> d_kv_lens = [2048] * 128
>>> # Prefill plan inputs
>>> p_seq_lens_blocks = torch.ceil(
...     torch.tensor(p_kv_lens, dtype=torch.int32) / page_block_size
... ).int()
>>> p_q_indptr = torch.cat(
...     [torch.tensor([0]), torch.cumsum(torch.tensor(p_qo_lens), 0)], dim=0
... ).int()
>>> p_kv_indptr = torch.cat(
...     [torch.tensor([0]), torch.cumsum(p_seq_lens_blocks, 0)], dim=0
... ).int()
>>> kv_indices_p = torch.arange(0, p_kv_indptr[-1], device=device, dtype=torch.int32)
>>> # entries in each request's last page, computed from token counts
>>> last_page_len_p = (torch.tensor(p_kv_lens, dtype=torch.int32) - 1) % page_block_size + 1
>>> # Decode plan inputs
>>> d_seq_lens_blocks = torch.ceil(
...     torch.tensor(d_kv_lens, dtype=torch.int32) / page_block_size
... ).int()
>>> d_q_indptr = torch.cat(
...     [torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0
... ).int()
>>> d_kv_indptr = torch.cat(
...     [torch.tensor([0]), torch.cumsum(d_seq_lens_blocks, 0)], dim=0
... ).int()
>>> kv_indices_d = torch.arange(0, d_kv_indptr[-1], device=device, dtype=torch.int32)
>>> last_page_len_d = (torch.tensor(d_kv_lens, dtype=torch.int32) - 1) % page_block_size + 1
>>> # create auxiliary data structures for batched prefill and decode attention
>>> wrapper.plan(
...     # Prefill params
...     p_q_indptr.to(device),
...     p_kv_indptr.to(device),
...     kv_indices_p.to(device),
...     last_page_len_p,
...     # Decode params
...     d_q_indptr.to(device),
...     d_kv_indptr.to(device),
...     kv_indices_d.to(device),
...     last_page_len_d,
...     # Common params
...     num_qo_heads=num_qo_heads,
...     num_kv_heads=num_kv_heads,
...     head_dim=head_dim,
...     page_size=page_block_size,
...     q_data_type=torch.bfloat16,
...     kv_data_type=torch.bfloat16,
... )
>>> # Prefill input tensors
>>> q_p = torch.rand(p_q_indptr[-1].item(), num_qo_heads, head_dim).to(
...     device, dtype=torch.bfloat16
... )
>>> kv_p = torch.randn(p_kv_indptr[-1].item(), 2, page_block_size, num_kv_heads, head_dim).to(
...     device, dtype=torch.bfloat16
... ).unbind(1)
>>> # Decode input tensors
>>> q_d = torch.rand(d_q_indptr[-1].item(), num_qo_heads, head_dim).to(
...     device, dtype=torch.bfloat16
... )
>>> kv_d = torch.randn(d_kv_indptr[-1].item(), 2, page_block_size, num_kv_heads, head_dim).to(
...     device, dtype=torch.bfloat16
... ).unbind(1)
>>> for i in range(num_layers):
...     o_p_batch, o_d_batch = wrapper.run(
...         q_p,
...         kv_p,
...         q_d,
...         kv_d,
...         causal_p=causal,
...     )
...
>>> print(o_p_batch.shape, o_d_batch.shape)
torch.Size([4096, 64, 128]) torch.Size([128, 64, 128])
Note
To accelerate computation, FlashInfer’s POD-Attention creates some auxiliary data structures. These data structures can be reused across multiple batched POD attention calls (e.g. different Transformer layers), and this wrapper class manages their lifecycle.
- __init__(float_workspace_buffer: Tensor, kv_layout: str = 'NHD') None
Constructor of BatchPODWithPagedKVCacheWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user-reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB; the workspace buffer should be on the same device as the input tensors.
kv_layout (str) – The layout of the input k/v tensors, either NHD or HND.
- plan(qo_indptr_p: Tensor, kv_indptr_p: Tensor, kv_indices_p: Tensor, last_page_len_p: Tensor, qo_indptr_d: Tensor, kv_indptr_d: Tensor, kv_indices_d: Tensor, last_page_len_d: Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, pos_encoding_mode: str = 'NONE', window_left: int = -1, q_data_type: str | dtype | None = 'float16', kv_data_type: str | dtype | None = None, data_type: str | dtype | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, non_blocking: bool = True) None
Plan POD’s batch prefill and decode for the given problem specification.
- Parameters:
qo_indptr_p (torch.Tensor) – The prefill indptr of the query/output tensor, shape: [batch_size_p + 1].
kv_indptr_p (torch.Tensor) – The prefill indptr of the paged kv-cache, shape: [batch_size_p + 1].
kv_indices_p (torch.Tensor) – The prefill page indices of the paged kv-cache, shape: [kv_indptr_p[-1]].
last_page_len_p (torch.Tensor) – The number of entries in the last page of each prefill request in the paged kv-cache, shape: [batch_size_p].
qo_indptr_d (torch.Tensor) – The decode indptr of the query/output tensor, shape: [batch_size_d + 1].
kv_indptr_d (torch.Tensor) – The decode indptr of the paged kv-cache, shape: [batch_size_d + 1].
kv_indices_d (torch.Tensor) – The decode page indices of the paged kv-cache, shape: [kv_indptr_d[-1]].
last_page_len_d (torch.Tensor) – The number of entries in the last page of each decode request in the paged kv-cache, shape: [batch_size_d].
num_qo_heads (int) – The number of query/output heads.
num_kv_heads (int) – The number of key/value heads.
head_dim (int) – The dimension of the heads.
page_size (int) – The page size of the paged kv cache.
pos_encoding_mode (str) – The position encoding applied inside attention kernels, one of NONE, ROPE_LLAMA (LLaMA-style rotary embedding) or ALIBI. Defaults to NONE.
window_left (int) – The left (inclusive) window size for the attention window. When set to -1, the window covers the full length of the sequence. Defaults to -1.
q_data_type (Optional[Union[str, torch.dtype]]) – The data type of the query tensor. Defaults to torch.float16.
kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, it is set to q_data_type. Defaults to None.
data_type (Optional[Union[str, torch.dtype]]) – The data type of both the query and key/value tensors. Defaults to torch.float16. data_type is deprecated; use q_data_type and kv_data_type instead.
sm_scale (Optional[float]) – The scale used in softmax. If not provided, it is set to 1.0 / sqrt(head_dim).
rope_scale (Optional[float]) – The scale used in RoPE interpolation. If not provided, it is set to 1.0.
rope_theta (Optional[float]) – The theta used in RoPE. If not provided, it is set to 1e4.
non_blocking (bool) – Whether to copy the input tensors to the device asynchronously. Defaults to True.
Note
The plan() method should be called before any run() or run_return_lse() calls; auxiliary data structures are created during this call and cached for multiple run calls.
The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function uses grouped query attention.
The plan() method cannot be used in CUDA Graph or in torch.compile.
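Each side of this plan() takes its own (qo_indptr, kv_indptr, kv_indices, last_page_len) tuple, and the documented shape rules must hold for both. The helper below is hypothetical and not part of FlashInfer. It checks those rules in pure Python before the arrays are converted to tensors, together with the invariant 1 <= last_page_len <= page_size.

```python
from typing import Sequence


def check_paged_plan_inputs(qo_indptr: Sequence[int], kv_indptr: Sequence[int],
                            kv_indices: Sequence[int], last_page_len: Sequence[int],
                            page_size: int) -> int:
    """Hypothetical sanity check for one side (prefill or decode) of plan().

    Returns the batch size implied by the arrays, or raises ValueError.
    """
    batch_size = len(qo_indptr) - 1
    if batch_size < 0 or len(kv_indptr) != batch_size + 1:
        raise ValueError("qo_indptr and kv_indptr must both have batch_size + 1 entries")
    if len(last_page_len) != batch_size:
        raise ValueError("last_page_len must have batch_size entries")
    for ptr in (qo_indptr, kv_indptr):
        # indptr arrays are prefix sums: start at 0, never decrease
        if ptr[0] != 0 or any(b < a for a, b in zip(ptr, ptr[1:])):
            raise ValueError("indptr arrays must start at 0 and be non-decreasing")
    if len(kv_indices) != kv_indptr[-1]:
        raise ValueError("kv_indices must have kv_indptr[-1] entries")
    if any(not 1 <= n <= page_size for n in last_page_len):
        raise ValueError("each last_page_len must lie in [1, page_size]")
    return batch_size
```

For the decode side of the class example (128 requests, one query token each, 2048 KV entries each, page_size 1), the check passes and reports a batch size of 128.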
- run(q_p: Tensor, paged_kv_cache_p: Tensor | Tuple[Tensor, Tensor], q_d: Tensor, paged_kv_cache_d: Tensor | Tuple[Tensor, Tensor], custom_mask_p: Tensor | None = None, packed_custom_mask_p: Tensor | None = None, causal_p: bool = False, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, return_lse: bool = False, use_fp16_qk_reduction: bool = False, enable_pdl: bool | None = None) Tuple[Tensor, Tensor] | Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]
Compute batched POD (Prefill-On-Decode) fused attention.
Fused single-shot attention that runs batched paged prefill (q_p against paged_kv_cache_p) and batched paged decode (q_d against paged_kv_cache_d) in the same kernel launch, sharing scheduling and execution resources. All prefill / decode shape and policy parameters (pos-encoding, sliding window, sm_scale, RoPE, etc.) are taken from the cached values supplied to plan().
- Parameters:
q_p (torch.Tensor) – Prefill query tensor, shape [qo_indptr_p[-1], num_qo_heads, head_dim].
paged_kv_cache_p (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – Paged KV cache for the prefill requests. Layout matches kv_layout set in __init__().
q_d (torch.Tensor) – Decode query tensor, shape [batch_size_d, num_qo_heads, head_dim].
paged_kv_cache_d (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – Paged KV cache for the decode requests. Layout matches kv_layout set in __init__().
custom_mask_p (Optional[torch.Tensor]) – Optional dense custom mask for the prefill side (auto-packed when packed_custom_mask_p is None).
packed_custom_mask_p (Optional[torch.Tensor]) – Optional bit-packed custom mask for the prefill side.
causal_p (bool) – Whether to apply a causal mask to the prefill side. Defaults to False.
q_scale (Optional[float]) – FP8 calibration scale for the decode queries, folded into the decode sm_scale.
k_scale (Optional[float]) – FP8 calibration scale for the decode keys, folded into the decode sm_scale.
v_scale (Optional[float]) – FP8 calibration scale for the decode values, applied to the decode kernel output.
return_lse (bool) – Whether to return the LSE tensors for both prefill and decode. Defaults to False.
use_fp16_qk_reduction (bool) – Whether to accumulate QK in FP16 (lower precision, higher throughput). Defaults to False.
enable_pdl (Optional[bool]) – Programmatic Dependent Launch toggle. When None (default), the wrapper auto-detects support from the query device.
- Returns:
By default (out_p, out_d): the prefill output and the decode output. When return_lse is True, the return value becomes ((out_p, lse_p), (out_d, lse_d)).
- Return type:
Union[Tuple[torch.Tensor, torch.Tensor], Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]]
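The q_scale / k_scale / v_scale parameters of run() are described as being folded into the decode sm_scale (q_scale, k_scale) or applied to the kernel output (v_scale). The pure-Python sketch below is a hypothetical reference, not the kernel, and it shows why that matches dequantizing q, k and v first. The QK dot product is bilinear, so q_scale and k_scale multiply every logit, just like sm_scale. The softmax weights sum to one, so a uniform scale on V passes straight through to the output.

```python
import math
from typing import List


def attend(q: List[float], keys: List[List[float]], values: List[List[float]], sm_scale: float) -> List[float]:
    """Single-query softmax attention in pure Python (max-subtracted for stability)."""
    logits = [sm_scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def attend_with_folded_scales(q, keys, values, sm_scale, q_scale, k_scale, v_scale):
    """Hypothetical reference: apply dequant scales without rescaling q, k or v.

    q_scale and k_scale are folded into the softmax scale; v_scale is
    applied once to the attention output.
    """
    out = attend(q, keys, values, sm_scale * q_scale * k_scale)
    return [v_scale * x for x in out]
```

Comparing this against attend() on explicitly rescaled q, keys and values gives the same result up to floating-point rounding, while the folded form touches each output element once instead of every input element.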