flashinfer.gdn_decode.gated_delta_rule_mtp
- flashinfer.gdn_decode.gated_delta_rule_mtp(q: Tensor, k: Tensor, v: Tensor, initial_state: Tensor, initial_state_indices: Tensor, A_log: Tensor, a: Tensor, dt_bias: Tensor, b: Tensor, scale: float | None = None, output: Tensor | None = None, intermediate_states_buffer: Tensor | None = None, disable_state_update: bool | None = None, use_qk_l2norm: bool = True) → Tuple[Tensor, Tensor]
Gated Delta Rule MTP kernel (Multiple Token Processing).
Processes multiple tokens (T > 1) per call, typically for speculative decoding verification. Supports caching intermediate states so that the state can be rolled back when needed.
- Parameters:
  - q (torch.Tensor) – Query tensor of shape [B, T, H, K].
  - k (torch.Tensor) – Key tensor of shape [B, T, H, K].
  - v (torch.Tensor) – Value tensor of shape [B, T, HV, V].
  - initial_state (torch.Tensor) – Initial state tensor of shape [pool_size, HV, V, K] (K-last layout). Must be float32. This standalone MTP entry point does not support the BF16 fast path; for a BF16 state pool with K = V = 128, call gated_delta_rule_decode_pretranspose() instead, which dispatches to the BF16 MTP kernel when T > 1.
  - initial_state_indices (torch.Tensor) – Indices mapping each batch entry to its initial state (a slot in the state pool), shape [B].
  - A_log (torch.Tensor) – Log decay parameter of shape [HV].
  - a (torch.Tensor) – Input-dependent decay of shape [B, T, HV].
  - dt_bias (torch.Tensor) – Decay bias of shape [HV].
  - b (torch.Tensor) – Update gate input of shape [B, T, HV].
  - scale (float, optional) – Scaling factor for queries. If None, uses 1 / sqrt(K).
  - output (torch.Tensor, optional) – Pre-allocated output tensor of shape [B, T, HV, V].
  - intermediate_states_buffer (torch.Tensor, optional) – Buffer for caching intermediate states, shape [pool_size, T, HV, V, K]. When None, intermediate states are not cached.
  - disable_state_update (bool, optional) – If True, the initial state is not updated. Currently defaults to True.
    Deprecated: the implicit default of True is deprecated and will change to False in FlashInfer 0.7.0. Pass disable_state_update=True or disable_state_update=False explicitly to silence the deprecation warning.
  - use_qk_l2norm (bool) – Whether to apply L2 normalization to q and k. Default: True.
- Returns:
  (output, initial_state), where output has shape [B, T, HV, V] and initial_state is the updated state (unchanged when disable_state_update=True).
- Return type:
  Tuple[torch.Tensor, torch.Tensor]
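The shape contract in the parameter list can be checked before launching the kernel. The sketch below is a hypothetical helper, check_mtp_shapes, which is not part of FlashInfer. It works on plain shape tuples (for example tuple(q.shape)) so it runs without a GPU, and it returns the expected output shape together with the effective query scale (1 / sqrt(K) when scale is None). It only encodes the shapes documented above; it does not check dtypes or any additional constraints the kernel may impose.

```python
import math
from typing import Optional, Sequence, Tuple

Shape = Sequence[int]


def _expect(name: str, got: Shape, want: Tuple[int, ...]) -> None:
    """Raise ValueError if a tensor shape differs from the documented one."""
    if tuple(got) != want:
        raise ValueError(f"{name}: expected shape {want}, got {tuple(got)}")


def check_mtp_shapes(
    q: Shape,
    k: Shape,
    v: Shape,
    initial_state: Shape,
    initial_state_indices: Shape,
    A_log: Shape,
    a: Shape,
    dt_bias: Shape,
    b: Shape,
    intermediate_states_buffer: Optional[Shape] = None,
    scale: Optional[float] = None,
) -> Tuple[Tuple[int, int, int, int], float]:
    """Hypothetical pre-flight check mirroring the documented shapes.

    Returns (output_shape, effective_scale).
    """
    if len(q) != 4 or len(v) != 4 or len(initial_state) != 4:
        raise ValueError("q, v and initial_state must be 4-D")
    B, T, H, K = q
    HV, V = v[2], v[3]
    pool_size = initial_state[0]

    _expect("k", k, (B, T, H, K))
    _expect("v", v, (B, T, HV, V))  # checks that v shares B and T with q
    # K-last layout: the state's last dimension is K, not V.
    _expect("initial_state", initial_state, (pool_size, HV, V, K))
    _expect("initial_state_indices", initial_state_indices, (B,))
    _expect("A_log", A_log, (HV,))
    _expect("dt_bias", dt_bias, (HV,))
    _expect("a", a, (B, T, HV))
    _expect("b", b, (B, T, HV))
    if intermediate_states_buffer is not None:
        _expect(
            "intermediate_states_buffer",
            intermediate_states_buffer,
            (pool_size, T, HV, V, K),
        )

    effective_scale = scale if scale is not None else 1.0 / math.sqrt(K)
    return (B, T, HV, V), effective_scale
```

With K = 64 and V = 128, for example, passing a state shaped [pool_size, HV, K, V] instead of the K-last [pool_size, HV, V, K] raises a ValueError, which is the layout mistake this check is most likely to catch.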
Notes
Requires SM90 (Hopper) architecture.
Supports T > 1 (multiple token processing).
State layout is K-last: [pool_size, HV, V, K].
Optimized for speculative decoding verification scenarios.
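When validating results on small inputs, a pure-Python reference for a single (batch entry, value head) pair can help. This is a sketch under assumptions that the text above does not state: it assumes Qwen3-Next-style Gated DeltaNet gating, where the per-token decay is alpha = exp(-exp(A_log) * softplus(a + dt_bias)) and the update gate is beta = sigmoid(b), and it assumes the decay is applied before the delta-rule correction. It also assumes the caller has already picked the q/k head that corresponds to the value head. The state is indexed S[v][k], matching one head of the K-last [pool_size, HV, V, K] layout, and the per-token snapshots mirror one head's slice of intermediate_states_buffer. The name gdn_reference_one_head is hypothetical and is not a FlashInfer function.

```python
import math
from typing import List, Optional, Sequence, Tuple

Matrix = List[List[float]]  # indexed [V][K], i.e. K-last like initial_state


def _l2norm(x: Sequence[float], eps: float = 1e-6) -> List[float]:
    norm = math.sqrt(sum(t * t for t in x) + eps)
    return [t / norm for t in x]


def _softplus(x: float) -> float:
    # log(1 + e^x); use the identity for large x to avoid overflow
    return x if x > 20.0 else math.log1p(math.exp(x))


def _sigmoid(x: float) -> float:
    # numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def gdn_reference_one_head(
    q_seq: Sequence[Sequence[float]],  # [T][K]
    k_seq: Sequence[Sequence[float]],  # [T][K]
    v_seq: Sequence[Sequence[float]],  # [T][V]
    a_seq: Sequence[float],            # [T]
    b_seq: Sequence[float],            # [T]
    A_log: float,
    dt_bias: float,
    state: Matrix,                     # [V][K], not modified
    scale: Optional[float] = None,
    use_qk_l2norm: bool = True,
) -> Tuple[List[List[float]], Matrix, List[Matrix]]:
    """Return (outputs [T][V], final state [V][K], per-token states [T][V][K])."""
    K = len(q_seq[0])
    if scale is None:
        scale = 1.0 / math.sqrt(K)
    S = [row[:] for row in state]  # work on a copy; the caller's state is untouched
    outputs: List[List[float]] = []
    snapshots: List[Matrix] = []
    for q, k, v, a, b in zip(q_seq, k_seq, v_seq, a_seq, b_seq):
        if use_qk_l2norm:
            q, k = _l2norm(q), _l2norm(k)
        alpha = math.exp(-math.exp(A_log) * _softplus(a + dt_bias))  # assumed decay
        beta = _sigmoid(b)  # assumed update gate
        for i in range(len(S)):
            row = [alpha * s for s in S[i]]                # decay the old memory
            pred = sum(row[j] * k[j] for j in range(K))    # state's prediction for key k
            delta = beta * (v[i] - pred)                   # delta-rule correction
            S[i] = [row[j] + delta * k[j] for j in range(K)]
        outputs.append([scale * sum(S[i][j] * q[j] for j in range(K)) for i in range(len(S))])
        snapshots.append([r[:] for r in S])  # what a per-token state cache would hold
    return outputs, S, snapshots
```

In this sketch, rolling back after verification amounts to keeping snapshots[n_accepted - 1], and returning a new state instead of writing into the input resembles running with disable_state_update=True, where the caller decides which state to commit.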