flashinfer.gdn_decode.gated_delta_rule_mtp

flashinfer.gdn_decode.gated_delta_rule_mtp(q: Tensor, k: Tensor, v: Tensor, initial_state: Tensor, initial_state_indices: Tensor, A_log: Tensor, a: Tensor, dt_bias: Tensor, b: Tensor, scale: float | None = None, output: Tensor | None = None, intermediate_states_buffer: Tensor | None = None, disable_state_update: bool | None = None, use_qk_l2norm: bool = True) → Tuple[Tensor, Tensor]

Gated Delta Rule MTP kernel (Multiple Token Processing).

Processes multiple tokens (T > 1) per call, typically to verify draft tokens in speculative decoding. Intermediate states can optionally be cached so that the state can be rolled back later.

Parameters:
  • q (torch.Tensor) – Query tensor of shape [B, T, H, K].

  • k (torch.Tensor) – Key tensor of shape [B, T, H, K].

  • v (torch.Tensor) – Value tensor of shape [B, T, HV, V].

  • initial_state (torch.Tensor) – Initial state tensor of shape [pool_size, HV, V, K] (K-last layout). Must be float32: this standalone MTP entry point does not support the BF16 fast path. For a BF16 K=V=128 state pool, call gated_delta_rule_decode_pretranspose() instead, which dispatches into the BF16 MTP kernel when T > 1.

  • initial_state_indices (torch.Tensor) – Indices mapping each batch element to its entry in initial_state, shape [B].

  • A_log (torch.Tensor) – Log decay parameter of shape [HV].

  • a (torch.Tensor) – Input-dependent decay of shape [B, T, HV].

  • dt_bias (torch.Tensor) – Decay bias of shape [HV].

  • b (torch.Tensor) – Update gate input of shape [B, T, HV].

  • scale (float, optional) – Scaling factor for queries. If None, uses 1 / sqrt(K).

  • output (torch.Tensor, optional) – Pre-allocated output tensor of shape [B, T, HV, V].

  • intermediate_states_buffer (torch.Tensor, optional) – Buffer for caching intermediate states, shape [pool_size, T, HV, V, K]. When None, intermediate states are not cached.

  • disable_state_update (bool, optional) –

    If True, the initial state is not updated. Currently defaults to True; the default will change to False in FlashInfer 0.7.0.

    Deprecated: The implicit default of True is deprecated and will change to False in version 0.7.0. Pass disable_state_update=True or disable_state_update=False explicitly to silence the warning.

  • use_qk_l2norm (bool) – Whether to apply L2 normalization to q and k. Default: True.
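
The shapes listed above can be collected in one place, which may help when allocating or validating inputs. The following is a minimal sketch in plain Python; expected_shapes and default_scale are hypothetical helpers written for illustration and are not part of FlashInfer.

```python
import math


def expected_shapes(B, T, H, HV, K, V, pool_size):
    """Return the documented shape of each tensor argument as a tuple.

    Hypothetical helper: it only restates the shapes from the parameter list.
    """
    return {
        "q": (B, T, H, K),
        "k": (B, T, H, K),
        "v": (B, T, HV, V),
        "initial_state": (pool_size, HV, V, K),  # K-last layout
        "initial_state_indices": (B,),
        "A_log": (HV,),
        "a": (B, T, HV),
        "dt_bias": (HV,),
        "b": (B, T, HV),
        "output": (B, T, HV, V),
        "intermediate_states_buffer": (pool_size, T, HV, V, K),
    }


def default_scale(K):
    """Value documented for scale=None: 1 / sqrt(K)."""
    return 1.0 / math.sqrt(K)


# Example dimensions chosen only for illustration.
shapes = expected_shapes(B=2, T=4, H=16, HV=32, K=128, V=128, pool_size=8)
print(shapes["initial_state"])  # (8, 32, 128, 128)
print(default_scale(128))
```

Note that q and k use H heads while v, the state, and the output use HV heads, so the two head counts should be kept distinct when allocating buffers.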

Returns:

(output, initial_state) where output has shape [B, T, HV, V] and initial_state is the updated state (unchanged when disable_state_update=True).

Return type:

Tuple[torch.Tensor, torch.Tensor]

Notes

  • Requires SM90 (Hopper) architecture.

  • Supports T > 1 (multiple token processing).

  • State layout is K-last: [pool_size, HV, V, K].

  • Optimized for speculative decoding verification scenarios.
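
A call following the signature above might look like the sketch below. Several details are assumptions that this page does not state: the dtypes of q, k, v, a, b, A_log and dt_bias, the int32 index dtype, and the interpretation of "SM90" as compute capability 9.x. The helpers is_hopper and run_mtp_sketch are hypothetical names used only for this illustration.

```python
def is_hopper(capability):
    """Return True for compute capability 9.x.

    Assumption: "Requires SM90" is read here as major version == 9.
    """
    major, _minor = capability
    return major == 9


def run_mtp_sketch(B=2, T=4, H=16, HV=32, K=128, V=128, pool_size=8):
    """Allocate inputs with the documented shapes and call the kernel once.

    Returns the output shape, or None when no suitable GPU is present.
    """
    # Local imports so is_hopper() can be used without torch or flashinfer.
    import torch
    from flashinfer.gdn_decode import gated_delta_rule_mtp

    if not torch.cuda.is_available():
        return None
    if not is_hopper(torch.cuda.get_device_capability()):
        return None

    dev = "cuda"
    act = torch.bfloat16  # assumption: activation dtype is not documented here
    q = torch.randn(B, T, H, K, device=dev, dtype=act)
    k = torch.randn(B, T, H, K, device=dev, dtype=act)
    v = torch.randn(B, T, HV, V, device=dev, dtype=act)

    # Documented requirement: the state pool must be float32, K-last layout.
    initial_state = torch.zeros(pool_size, HV, V, K, device=dev, dtype=torch.float32)
    # Assumption: int32 indices; batch element i uses pool entry i.
    initial_state_indices = torch.arange(B, device=dev, dtype=torch.int32)

    # Assumption: float32 per-head parameters, activation-dtype gate inputs.
    A_log = torch.zeros(HV, device=dev, dtype=torch.float32)
    dt_bias = torch.zeros(HV, device=dev, dtype=torch.float32)
    a = torch.randn(B, T, HV, device=dev, dtype=act)
    b = torch.randn(B, T, HV, device=dev, dtype=act)

    # disable_state_update is passed explicitly, as the deprecation note advises.
    output, _state = gated_delta_rule_mtp(
        q, k, v, initial_state, initial_state_indices,
        A_log, a, dt_bias, b,
        disable_state_update=True,
    )
    return tuple(output.shape)  # documented as [B, T, HV, V]
```

Calling run_mtp_sketch() on a machine without a Hopper GPU returns None rather than raising, which keeps the sketch safe to import in CPU-only environments.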