flashinfer.gdn_decode.gated_delta_rule_mtp
- flashinfer.gdn_decode.gated_delta_rule_mtp(q: Tensor, k: Tensor, v: Tensor, initial_state: Tensor, initial_state_indices: Tensor, A_log: Tensor, a: Tensor, dt_bias: Tensor, b: Tensor, scale: float | None = None, output: Tensor | None = None, intermediate_states_buffer: Tensor | None = None, ssm_state_indices: Tensor | None = None, disable_state_update: bool | None = None, use_qk_l2norm: bool = True, output_state_indices: Tensor | None = None) → Tuple[Tensor, Tensor]
Gated Delta Rule MTP kernel (Multiple Token Processing).
Processes multiple tokens (T > 1) per call, typically used for speculative decoding verification. Supports intermediate state caching for potential rollback scenarios.
- Parameters:
q (torch.Tensor) – Query tensor of shape [B, T, H, K].
k (torch.Tensor) – Key tensor of shape [B, T, H, K].
v (torch.Tensor) – Value tensor of shape [B, T, HV, V].
initial_state (torch.Tensor) – Initial state pool of shape [pool_size, HV, V, K] (K-last layout). Must be float32: this standalone MTP entry point does not support the BF16 fast path. For a BF16 K=V=128 state pool, call gated_delta_rule_decode_pretranspose() instead (which dispatches into the BF16 MTP kernel when T > 1). When the pool is contiguous, the kernel reads and writes it in place via the free 4D→3D reshape view; a non-contiguous pool is dispatched through the native 4D use_pool_indexing=True path, and the kernel writes the strided pool in place without densification.
initial_state_indices (torch.Tensor) – Read indices mapping each batch to its slot in initial_state, shape [B]. Negative entries are treated as padding: the kernel skips both the read and the writeback for that batch, and the output slot is left as the caller-allocated value (zero when output is None).
A_log (torch.Tensor) – Log decay parameter of shape [HV].
a (torch.Tensor) – Input-dependent decay of shape [B, T, HV].
dt_bias (torch.Tensor) – Decay bias of shape [HV].
b (torch.Tensor) – Update gate input of shape [B, T, HV].
scale (float, optional) – Scaling factor for queries. If None, uses 1 / sqrt(K).
output (torch.Tensor, optional) – Pre-allocated output tensor of shape [B, T, HV, V].
intermediate_states_buffer (torch.Tensor, optional) – Buffer for caching intermediate states, shape [B, T, HV, V, K]. The first dimension is indexed per batch, not per pool slot, so the buffer must have at least B rows and be contiguous; it must be float32 when provided. When None, intermediate states are not cached. Mutually exclusive with ssm_state_indices.
ssm_state_indices (torch.Tensor, optional) – Per-token pool scatter indices of shape [B, T] and dtype torch.int32. When provided, the kernel writes each intermediate hidden state h_{t+1} directly to initial_state[ssm_state_indices[i, t]] instead of accumulating into a dense intermediate_states_buffer. Useful for FLA-style speculative-decoding flows where each draft token needs its own pool slot. Constraints: T >= 2, disable_state_update=False, mutually exclusive with intermediate_states_buffer. Default: None.
disable_state_update (bool, optional) – If True, the initial state is not updated. Currently defaults to True. Deprecated: the implicit default of True will change to False in FlashInfer 0.7.0; pass disable_state_update=True or disable_state_update=False explicitly to silence the deprecation warning.
use_qk_l2norm (bool) – Whether to apply L2 normalization to q and k. Default: True.
output_state_indices (torch.Tensor, optional) – Write indices of shape [B] (int32 or int64) specifying the destination pool slot for each batch's updated state. Defaults to initial_state_indices (read and write target the same slot). Negative entries skip the writeback for that batch (the read still runs).
- Returns:
(output, initial_state), where output has shape [B, T, HV, V] and initial_state is the updated state (unchanged when disable_state_update=True).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
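The index and caching rules in the parameter list can be modeled in a few lines of plain Python. The sketch below is illustrative only: plan_state_io and check_mtp_state_args are hypothetical helpers, not part of FlashInfer. It assumes that disable_state_update=True suppresses every final-state writeback, treats an omitted disable_state_update as the current implicit default of True, and does not model the per-token ssm_state_indices scatter.

```python
from typing import List, Optional, Tuple


def plan_state_io(
    initial_state_indices: List[int],
    output_state_indices: Optional[List[int]] = None,
    disable_state_update: bool = False,
) -> List[Tuple[Optional[int], Optional[int]]]:
    """Hypothetical model: (read_slot, write_slot) per batch entry; None means skipped."""
    if output_state_indices is None:
        # Documented default: read and write target the same slot.
        output_state_indices = initial_state_indices
    plan = []
    for read_idx, write_idx in zip(initial_state_indices, output_state_indices):
        if read_idx < 0:
            # Padding entry: both the read and the writeback are skipped.
            plan.append((None, None))
        elif disable_state_update or write_idx < 0:
            # The read still runs, but nothing is written back.
            plan.append((read_idx, None))
        else:
            plan.append((read_idx, write_idx))
    return plan


def check_mtp_state_args(
    T: int,
    has_intermediate_states_buffer: bool,
    has_ssm_state_indices: bool,
    disable_state_update: Optional[bool],
) -> None:
    """Hypothetical pre-flight check; raises ValueError on a violated rule."""
    if has_intermediate_states_buffer and has_ssm_state_indices:
        raise ValueError(
            "intermediate_states_buffer and ssm_state_indices are mutually exclusive"
        )
    if has_ssm_state_indices:
        if T < 2:
            raise ValueError("ssm_state_indices requires T >= 2")
        # None currently means the implicit default of True, so this sketch
        # rejects it as well as an explicit True (an assumption about intent).
        if disable_state_update is not False:
            raise ValueError("ssm_state_indices requires disable_state_update=False")


print(plan_state_io([3, -1, 5], [7, 2, -1]))
```

With read indices [3, -1, 5] and write indices [7, 2, -1], the plan is [(3, 7), (None, None), (5, None)]: the padded entry touches nothing, and the last entry reads slot 5 but skips the writeback.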
Notes
Requires SM90 (Hopper) architecture.
Supports T > 1 (multiple token processing).
State layout is K-last: [pool_size, HV, V, K].
Optimized for speculative decoding verification scenarios.
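A minimal call sketch, under stated assumptions, may help tie the shapes together. The shapes come from the parameter list; everything else is illustrative and not confirmed by this page: the sizes (B=2, T=4, H=16, HV=32, K=V=128, pool_size=8), the bfloat16 dtype for q, k, v, a, and b, the float32 dtype for A_log and dt_bias, and the int32 index dtype. mtp_shapes and run_mtp_example are hypothetical helpers; the latter imports torch and flashinfer lazily and needs an SM90 GPU to run.

```python
from typing import Dict, Tuple


def mtp_shapes(B: int, T: int, H: int, HV: int, K: int, V: int,
               pool_size: int) -> Dict[str, Tuple[int, ...]]:
    """Expected tensor shapes, copied from the parameter list."""
    return {
        "q": (B, T, H, K),
        "k": (B, T, H, K),
        "v": (B, T, HV, V),
        "initial_state": (pool_size, HV, V, K),  # K-last layout
        "initial_state_indices": (B,),
        "A_log": (HV,),
        "a": (B, T, HV),
        "dt_bias": (HV,),
        "b": (B, T, HV),
        "output": (B, T, HV, V),
    }


def run_mtp_example():
    """Call the kernel once; requires torch, flashinfer, and an SM90 GPU."""
    import torch
    from flashinfer.gdn_decode import gated_delta_rule_mtp

    s = mtp_shapes(B=2, T=4, H=16, HV=32, K=128, V=128, pool_size=8)
    dev = "cuda"
    # Assumed dtypes: bfloat16 activations, float32 gate parameters.
    q = torch.randn(s["q"], device=dev, dtype=torch.bfloat16)
    k = torch.randn(s["k"], device=dev, dtype=torch.bfloat16)
    v = torch.randn(s["v"], device=dev, dtype=torch.bfloat16)
    a = torch.randn(s["a"], device=dev, dtype=torch.bfloat16)
    b = torch.randn(s["b"], device=dev, dtype=torch.bfloat16)
    A_log = torch.zeros(s["A_log"], device=dev, dtype=torch.float32)
    dt_bias = torch.zeros(s["dt_bias"], device=dev, dtype=torch.float32)
    # The state pool must be float32 per the parameter list.
    state = torch.zeros(s["initial_state"], device=dev, dtype=torch.float32)
    # One pool slot per batch entry (assumed int32).
    idx = torch.tensor([0, 1], device=dev, dtype=torch.int32)

    out, state = gated_delta_rule_mtp(
        q, k, v, state, idx, A_log, a, dt_bias, b,
        disable_state_update=False,  # explicit, to avoid the deprecation warning
    )
    return out.shape  # expected to match s["output"]
```

Passing disable_state_update explicitly keeps the call independent of the default change planned for FlashInfer 0.7.0.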