flashinfer.gdn_decode.gated_delta_rule_decode_pretranspose
- flashinfer.gdn_decode.gated_delta_rule_decode_pretranspose(q: Tensor, k: Tensor, v: Tensor, state: Tensor | None, A_log: Tensor, a: Tensor, dt_bias: Tensor, b: Tensor, scale: float | None = None, output: Tensor | None = None, use_qk_l2norm: bool = True, initial_state: Tensor | None = None, initial_state_indices: Tensor | None = None, output_state_indices: Tensor | None = None) → Tuple[Tensor, Tensor]
Gated Delta Rule Decode kernel for single-token generation.
Implements the decode phase of gated delta rule linear attention, processing one token at a time and updating the recurrent state.
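To make the per-token update concrete, here is a minimal pure-Python sketch of one decode step for a single head. It assumes the commonly published gated delta rule formulation (decay `exp(-exp(A_log) * softplus(a + dt_bias))`, gate `beta = sigmoid(b)`); this page does not state the formulas, so the math, the `eps` placement in the L2 norm, and the function name `gated_delta_decode_step` are all assumptions rather than the kernel's confirmed behavior.

```python
import math


def gated_delta_decode_step(q, k, v, state, A_log, a, dt_bias, b,
                            scale=None, use_qk_l2norm=True, eps=1e-6):
    """One decode step for ONE head (illustrative reference, not the kernel).

    q, k: lists of length K; v: list of length V.
    state: V x K list of lists (v-major / K-last), updated in place.
    A_log, a, dt_bias, b: Python floats for this head.
    """
    K = len(k)
    if use_qk_l2norm:
        # Assumed normalization: x / sqrt(sum(x^2) + eps).
        qn = math.sqrt(sum(x * x for x in q) + eps)
        kn = math.sqrt(sum(x * x for x in k) + eps)
        q = [x / qn for x in q]
        k = [x / kn for x in k]
    if scale is None:
        scale = 1.0 / math.sqrt(K)  # documented default

    # Assumed gating: decay in (0, 1), beta in (0, 1).
    softplus = math.log1p(math.exp(a + dt_bias))
    decay = math.exp(-math.exp(A_log) * softplus)
    beta = 1.0 / (1.0 + math.exp(-b))

    out = []
    for row, v_i in zip(state, v):
        for j in range(K):
            row[j] *= decay                         # decay the old state
        pred = sum(row[j] * k[j] for j in range(K))  # what the state predicts for k
        delta = beta * (v_i - pred)                 # gated prediction error
        for j in range(K):
            row[j] += delta * k[j]                  # rank-1 delta-rule update
        out.append(scale * sum(row[j] * q[j] for j in range(K)))
    return out, state
```

The sketch ignores batching and the mapping between the `H` query/key heads and the `HV` value heads; it only shows how the state is read, decayed, corrected, and queried in one step.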
- Parameters:
  - q (torch.Tensor) – Current query of shape [B, 1, H, K]. Must be float16/bfloat16.
  - k (torch.Tensor) – Current key of shape [B, 1, H, K]. Must be float16/bfloat16.
  - v (torch.Tensor) – Current value of shape [B, 1, HV, V]. Must be float16/bfloat16.
  - state (torch.Tensor, optional) – Current state of shape [B, HV, V, K] (v-major / K-last layout). Float32: legacy kernel (T=1 only). Bfloat16: BF16 state backend (T=1 or MTP for T>1) when K=V=128. Updated in-place. Pass None when using initial_state/initial_state_indices instead.
  - A_log (torch.Tensor) – Log decay parameter of shape [HV]. Must be float32.
  - a (torch.Tensor) – Input-dependent decay of shape [B, 1, HV]. Must be float16/bfloat16.
  - dt_bias (torch.Tensor) – Decay bias of shape [HV]. Must be bfloat16 or float32.
  - b (torch.Tensor) – Update gate (beta) input of shape [B, 1, HV]. Must be float16/bfloat16.
  - scale (float, optional) – Scale factor for queries. If None, defaults to 1 / sqrt(K).
  - output (torch.Tensor, optional) – Pre-allocated output tensor of shape [B, 1, HV, V]. Allocated automatically when None.
  - use_qk_l2norm (bool) – Whether to apply L2 normalization to q and k. Default: True.
  - initial_state (torch.Tensor, optional) – State pool of shape [pool_size, HV, V, K] (K-last / K-contiguous, same layout as the per-batch state argument). When provided, the kernel gathers directly from the pool using initial_state_indices and writes updates back in-place, eliminating the caller-side gather/scatter overhead. Requires bfloat16 state with K=V=128 (bf16 fast path).
  - initial_state_indices (torch.Tensor, optional) – Per-batch indices of shape [B] (int32 or int64) mapping each batch entry to its slot in initial_state. Required when initial_state is provided.
  - output_state_indices (torch.Tensor, optional) – Per-batch indices of shape [B] (int32 or int64) specifying where to write the updated state for each batch entry in the pool. Requires initial_state to be provided. If None, the kernel writes the updated state back to the same slot it read from (i.e. initial_state_indices).
    Padding / inactive sequences: set the index to -1 for any batch entry that should be treated as padding. The two backends handle -1 differently:
    - bf16 fast path (bfloat16 state, K=V=128): -1 is redirected to initial_state[0], which acts as a sacrificial null buffer. The kernel reads from and writes back to slot 0; the output for that batch entry is computed but undefined (caller should not use it). The caller must therefore allocate the pool with an extra leading slot (pool_size = num_real_slots + 1) and keep real slots at indices 1..pool_size-1.
    - float32 legacy path (T=1): -1 entries are skipped entirely; neither the state pool nor the output is touched for that batch entry; the output slot is written as zero.
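The pool layout required by the bf16 fast path (slot 0 reserved, real slots at 1..pool_size-1, padding marked with -1) can be sketched with a small index-planning helper. The names `pool_size_for` and `build_state_indices` are hypothetical and not part of FlashInfer; the sketch only encodes the layout rule described above.

```python
NULL_SLOT = 0    # slot that the bf16 fast path redirects -1 entries to
PAD_INDEX = -1   # marks padding / inactive batch entries


def pool_size_for(num_real_slots):
    """Pool size with one extra leading sacrificial slot."""
    return num_real_slots + 1


def build_state_indices(real_slots, batch_size):
    """Build a length-`batch_size` index list for the state pool.

    real_slots: 0-based logical slot ids of the active sequences.
    Active entries are shifted by one so they never alias NULL_SLOT;
    the remaining batch entries are filled with PAD_INDEX.
    """
    if len(real_slots) > batch_size:
        raise ValueError("more active sequences than batch entries")
    indices = [slot + 1 for slot in real_slots]
    indices += [PAD_INDEX] * (batch_size - len(real_slots))
    return indices
```

The resulting list would then be converted to an int32 or int64 tensor of shape [B] before being passed as initial_state_indices.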
- Returns:
(output, state_or_initial_state) where output has shape [B, 1, HV, V] and the second element is the updated state (mutated in place).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
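A call on the float32 legacy path might look like the sketch below. The shape helper only restates the shapes from the parameter list; `run_decode_example` is a hypothetical driver that needs PyTorch, FlashInfer, and an SM90+ GPU, and the head counts (H=16, HV=32) are illustrative values that the kernel may or may not accept.

```python
def decode_tensor_shapes(B, H, HV, K, V):
    """Shapes documented above for a single-token (T=1) decode call."""
    return {
        "q": (B, 1, H, K),
        "k": (B, 1, H, K),
        "v": (B, 1, HV, V),
        "state": (B, HV, V, K),  # v-major / K-last
        "A_log": (HV,),
        "a": (B, 1, HV),
        "dt_bias": (HV,),
        "b": (B, 1, HV),
        "output": (B, 1, HV, V),
    }


def run_decode_example(B=2, H=16, HV=32, K=128, V=128):
    """Hypothetical driver: one decode step with a float32 per-batch state."""
    # Imported lazily so the shape helper works without a GPU stack.
    import torch
    from flashinfer.gdn_decode import gated_delta_rule_decode_pretranspose

    s = decode_tensor_shapes(B, H, HV, K, V)
    bf16 = dict(dtype=torch.bfloat16, device="cuda")
    fp32 = dict(dtype=torch.float32, device="cuda")

    q = torch.randn(s["q"], **bf16)
    k = torch.randn(s["k"], **bf16)
    v = torch.randn(s["v"], **bf16)
    a = torch.randn(s["a"], **bf16)
    b = torch.randn(s["b"], **bf16)
    A_log = torch.zeros(s["A_log"], **fp32)
    dt_bias = torch.zeros(s["dt_bias"], **fp32)
    state = torch.zeros(s["state"], **fp32)  # float32 selects the legacy path

    out, new_state = gated_delta_rule_decode_pretranspose(
        q=q, k=k, v=v, state=state, A_log=A_log, a=a, dt_bias=dt_bias, b=b
    )
    return out, new_state  # out: [B, 1, HV, V]; state was updated in place
```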
Notes
- Requires SM90+ (Hopper, Blackwell, etc.).
- State is always updated in-place; the pool path writes directly into initial_state memory (no separate scatter step needed).
- State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16 and K = V = 128, the BF16 state kernel is used (T=1 or MTP for T>1); the pool+indices path routes through the MTP kernel.
- Pool+indices (initial_state/initial_state_indices) are supported on both the bf16 fast path (K=V=128) and the float32 legacy path (T=1). Both paths support -1 padding indices (see initial_state_indices above for per-backend semantics).
- Legacy path (float32 state, T=1): K and V must each be >= 128, and V must be a multiple of 8 (the pretranspose tile size TILE_V).
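The dtype and shape constraints in these notes can be collected into a small pre-flight check. This is a sketch under stated assumptions: `select_state_backend` and its return labels are hypothetical, and raising on a bfloat16 state with K or V other than 128 is a choice made here, since the notes do not say what the library does in that case.

```python
def select_state_backend(state_dtype, K, V, T=1):
    """Return which documented path a call would take, or raise ValueError.

    state_dtype: "bfloat16" or "float32" (plain strings, to avoid a torch import).
    T: tokens per sequence (1 for plain decode, >1 for MTP).
    """
    if state_dtype == "bfloat16":
        if K == 128 and V == 128:
            return "bf16"  # T=1 or MTP for T>1
        # Assumption: treat other sizes as unsupported for bf16 state.
        raise ValueError("bfloat16 state requires K = V = 128")
    if state_dtype == "float32":
        if T != 1:
            raise ValueError("float32 legacy path supports T=1 only")
        if K < 128 or V < 128:
            raise ValueError("legacy path requires K >= 128 and V >= 128")
        if V % 8 != 0:
            raise ValueError("legacy path requires V to be a multiple of 8")
        return "float32_legacy"
    raise ValueError("state must be bfloat16 or float32")
```

Running a check like this before launching the kernel turns a shape or dtype mismatch into a readable Python error.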