flashinfer.gdn_decode.gated_delta_rule_decode_pretranspose

flashinfer.gdn_decode.gated_delta_rule_decode_pretranspose(q: Tensor, k: Tensor, v: Tensor, state: Tensor | None, A_log: Tensor, a: Tensor, dt_bias: Tensor, b: Tensor, scale: float | None = None, output: Tensor | None = None, use_qk_l2norm: bool = True, initial_state: Tensor | None = None, initial_state_indices: Tensor | None = None, output_state_indices: Tensor | None = None) → Tuple[Tensor, Tensor]

Gated Delta Rule Decode kernel for single-token generation.

Implements the decode phase of gated delta rule linear attention, processing one token at a time and updating the recurrent state.
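The per-token recurrence can be sketched in NumPy for a single (batch, value-head) pair. This is a reference sketch under stated assumptions, not FlashInfer's kernel. The gating formulas g = -exp(A_log) * softplus(a + dt_bias) and beta = sigmoid(b) are assumed from the common Gated DeltaNet convention and are not stated on this page. The helper names gdn_decode_step, softplus, and l2norm are hypothetical. The state S has shape [V, K], matching the v-major / K-last layout described for the state argument.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def l2norm(x, eps=1e-6):
    # Scale x to (approximately) unit L2 norm; eps guards against division by zero.
    return x / np.sqrt(np.sum(x * x) + eps)


def gdn_decode_step(q, k, v, S, A_log, a, dt_bias, b, scale=None, use_qk_l2norm=True):
    """One gated-delta-rule decode step for a single (batch, value-head) pair.

    q, k: [K]; v: [V]; S: [V, K] (v-major / K-last); A_log, a, dt_bias, b: scalars.
    Returns (o, S_new) with o: [V] and S_new: [V, K]. S is not modified in place.
    """
    if scale is None:
        scale = 1.0 / np.sqrt(q.shape[0])  # default 1 / sqrt(K)
    if use_qk_l2norm:
        q, k = l2norm(q), l2norm(k)
    # ASSUMED gating convention (not stated on this page):
    g = -np.exp(A_log) * softplus(a + dt_bias)  # log-decay, g <= 0
    alpha = np.exp(g)                           # decay factor in (0, 1]
    beta = 1.0 / (1.0 + np.exp(-b))             # update gate, sigmoid(b)
    S = alpha * S                               # decay the previous state
    delta = beta * (v - S @ k)                  # delta rule: correct S's prediction for k
    S = S + np.outer(delta, k)                  # rank-1 update, still [V, K]
    o = S @ (q * scale)                         # read out with the scaled query
    return o, S


# Usage: zero initial state with alpha ~ 1 and beta ~ 1.
K, V = 4, 3
k = np.array([1.0, 0.0, 0.0, 0.0])
v = np.array([1.0, 2.0, 3.0])
o, S1 = gdn_decode_step(k.copy(), k, v, np.zeros((V, K)),
                        A_log=-50.0, a=0.0, dt_bias=0.0, b=50.0,
                        scale=1.0, use_qk_l2norm=False)
print(o)  # prints [1. 2. 3.]
```

With negligible decay and a fully open gate, one step from a zero state stores v along the direction k, so querying with k reads v back. The real kernel applies this update to every batch entry and value head at once, and writes the new state in place instead of returning a copy.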

Parameters:
  • q (torch.Tensor) – Current query of shape [B, 1, H, K]. Must be float16/bfloat16.

  • k (torch.Tensor) – Current key of shape [B, 1, H, K]. Must be float16/bfloat16.

  • v (torch.Tensor) – Current value of shape [B, 1, HV, V]. Must be float16/bfloat16.

  • state (torch.Tensor, optional) – Current state of shape [B, HV, V, K] (v-major / K-last layout), updated in place. A float32 state selects the legacy kernel (T=1 only). A bfloat16 state with K=V=128 selects the BF16 state backend (T=1, or MTP for T>1). Pass None when using initial_state / initial_state_indices instead.

  • A_log (torch.Tensor) – Log decay parameter of shape [HV]. Must be float32.

  • a (torch.Tensor) – Input-dependent decay of shape [B, 1, HV]. Must be float16/bfloat16.

  • dt_bias (torch.Tensor) – Decay bias of shape [HV]. Must be bfloat16 or float32.

  • b (torch.Tensor) – Update gate (beta) input of shape [B, 1, HV]. Must be float16/bfloat16.

  • scale (float, optional) – Scale factor for queries. If None, defaults to 1 / sqrt(K).

  • output (torch.Tensor, optional) – Pre-allocated output tensor of shape [B, 1, HV, V]. Allocated automatically when None.

  • use_qk_l2norm (bool) – Whether to apply L2 normalization to q and k. Default: True.

  • initial_state (torch.Tensor, optional) – State pool of shape [pool_size, HV, V, K] (K-last / K-contiguous, same layout as the per-batch state argument). When provided, the kernel gathers directly from the pool using initial_state_indices and writes updates back in place, eliminating the caller-side gather/scatter overhead. Supported on both the bf16 fast path (bfloat16 state, K=V=128) and the float32 legacy path (T=1).

  • initial_state_indices (torch.Tensor, optional) – Per-batch indices of shape [B] (int32 or int64) mapping each batch entry to its slot in initial_state. Required when initial_state is provided.

  • output_state_indices (torch.Tensor, optional) –

    Per-batch indices of shape [B] (int32 or int64) specifying where to write the updated state for each batch entry in the pool. Requires initial_state to be provided. If None, the kernel writes the updated state back to the same slot it read from (i.e. initial_state_indices).

    Padding / inactive sequences: set the index to -1 for any batch entry that should be treated as padding. The two backends handle -1 differently:

    • bf16 fast path (bfloat16 state, K=V=128): -1 is redirected to initial_state[0], which acts as a sacrificial null buffer. The kernel reads from and writes back to slot 0. The output for that batch entry is still computed, but its contents are undefined and the caller must not use them. The caller must therefore allocate the pool with one extra leading slot (pool_size = num_real_slots + 1) and keep real slots at indices 1..pool_size-1.

    • float32 legacy path (T=1): -1 entries are skipped entirely. The state pool is neither read nor written for that batch entry, and its output slot is filled with zeros.
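The bf16-path -1 padding convention can be emulated on the host to check slot bookkeeping before launching the kernel. The sketch below is a NumPy stand-in, not the kernel. make_pool and emulate_bf16_pool_update are hypothetical names, and update_fn stands in for the per-entry decode computation. Gathering every state before scattering any is an assumption, because this page does not specify what happens when one entry's write slot aliases another entry's read slot.

```python
import numpy as np


def make_pool(num_real_slots, HV, V, K, dtype=np.float32):
    # Slot 0 is the sacrificial null slot for -1 padding (bf16-path convention);
    # real sequences occupy slots 1..num_real_slots.
    return np.zeros((num_real_slots + 1, HV, V, K), dtype=dtype)


def emulate_bf16_pool_update(pool, read_idx, write_idx, update_fn):
    """Host-side stand-in for the pool path's index handling (hypothetical helper).

    read_idx, write_idx: int arrays of shape [B]; -1 marks a padding entry.
    write_idx=None means "write back to the slot that was read".
    update_fn: maps one state [HV, V, K] to its updated state.
    """
    if write_idx is None:
        write_idx = read_idx
    # Redirect padding (-1) to the null slot 0, as described for the bf16 path.
    r = np.where(read_idx < 0, 0, read_idx)
    w = np.where(write_idx < 0, 0, write_idx)
    # Gather and compute every entry before scattering any (assumed ordering).
    new_states = [update_fn(pool[i]) for i in r]
    for slot, st in zip(w, new_states):
        pool[slot] = st  # scatter into the pool in place
    return pool


pool = make_pool(num_real_slots=3, HV=1, V=2, K=2)
pool[1:] = np.arange(1, 4, dtype=np.float32)[:, None, None, None]  # slot s holds value s
emulate_bf16_pool_update(pool, np.array([2, -1, 3]), np.array([1, -1, 3]),
                         update_fn=lambda s: s + 10.0)
```

Slot 1 receives the update computed from slot 2's state, and slot 2 itself is left untouched. The padding entry only disturbs the null slot 0, which is why real sequences must never be assigned index 0 on this path.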

Returns:

(output, state_or_initial_state), where output has shape [B, 1, HV, V] and the second element is the updated state tensor (state, or initial_state when the pool path is used), mutated in place.

Return type:

Tuple[torch.Tensor, torch.Tensor]

Notes

  • Requires SM90+ (Hopper, Blackwell, etc.).

  • State is always updated in-place; the pool path writes directly into initial_state memory (no separate scatter step needed).

  • State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16 and K=V=128, the BF16 state kernel is used (T=1, or MTP for T>1). On this bf16 path, pool+indices calls route through the MTP kernel.

  • Pool+indices (initial_state / initial_state_indices) are supported on both the bf16 fast path (K=V=128) and the float32 legacy path (T=1). Both paths support -1 padding indices. See output_state_indices above for the per-backend semantics.

  • Legacy path (float32 state, T=1): K and V must each be >= 128, and V must be a multiple of 8 (the pretranspose tile size TILE_V).
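The dispatch conditions in these notes can be restated as a small pure-Python helper, which may be handy for validating shapes and dtypes before a call. describe_backend and its return labels are hypothetical. They summarize the conditions documented on this page and are not FlashInfer's dispatch code. Raising ValueError for combinations this page does not document is an assumption, not observed library behavior.

```python
def describe_backend(state_dtype, K, V, T=1, uses_pool=False):
    """Hypothetical helper restating the documented backend-selection rules.

    state_dtype: "bfloat16" or "float32"; T: tokens per sequence in this call;
    uses_pool: True when initial_state / initial_state_indices are passed.
    """
    if state_dtype == "bfloat16":
        if K != 128 or V != 128:
            # Assumed failure mode: the page only documents bf16 with K=V=128.
            raise ValueError("bf16 state backend is documented for K == V == 128 only")
        # T > 1 and the pool+indices path both use the MTP kernel.
        return "bf16-mtp" if (uses_pool or T > 1) else "bf16"
    if state_dtype == "float32":
        if T != 1:
            raise ValueError("float32 legacy kernel supports T == 1 only")
        if K < 128 or V < 128 or V % 8 != 0:
            raise ValueError("legacy path needs K, V >= 128 and V % 8 == 0 (TILE_V)")
        return "legacy-fp32"
    raise ValueError(f"unsupported state dtype: {state_dtype}")


print(describe_backend("bfloat16", 128, 128, T=4))  # prints bf16-mtp
```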