flashinfer.gdn_decode.gated_delta_rule_mtp

flashinfer.gdn_decode.gated_delta_rule_mtp(q: Tensor, k: Tensor, v: Tensor, initial_state: Tensor, initial_state_indices: Tensor, A_log: Tensor, a: Tensor, dt_bias: Tensor, b: Tensor, scale: float | None = None, output: Tensor | None = None, intermediate_states_buffer: Tensor | None = None, ssm_state_indices: Tensor | None = None, disable_state_update: bool | None = None, use_qk_l2norm: bool = True, output_state_indices: Tensor | None = None) → Tuple[Tensor, Tensor]

Gated Delta Rule MTP kernel (Multiple Token Processing).

Processes multiple tokens (T > 1) per call, typically used for speculative decoding verification. Supports intermediate state caching for potential rollback scenarios.
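The following minimal pure-Python sketch shows the per-token recurrence for a single (batch, head) slice. It is not taken from the kernel. This page does not state the update equations, so the sketch assumes the standard gated delta rule with FLA-style gating: decay exp(g) with g = -exp(A_log) * softplus(a + dt_bias), and beta = sigmoid(b). It also ignores the H-to-HV head mapping, reduced precision, and the kernel's exact L2-norm epsilon. gdn_mtp_reference and its helpers are hypothetical names. The list of per-token states it returns plays the role of the intermediate state cache.

```python
import math
from typing import List, Optional, Tuple

Vec = List[float]
Mat = List[List[float]]  # state rows indexed [v][k] (K-last)


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x: float) -> float:
    # Branches avoid overflow in math.exp for large |x|.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def l2_normalize(x: Vec, eps: float = 1e-6) -> Vec:
    norm = max(math.sqrt(sum(e * e for e in x)), eps)
    return [e / norm for e in x]


def gdn_mtp_reference(
    q: Mat, k: Mat, v: Mat, h0: Mat,
    A_log: float, a: Vec, dt_bias: float, b: Vec,
    scale: Optional[float] = None, use_qk_l2norm: bool = True,
) -> Tuple[Mat, List[Mat]]:
    """One (batch, head) slice: q and k are T x K, v is T x V, h0 is V x K.

    Returns (outputs, intermediate_states): outputs is T x V and
    intermediate_states[t] is the state after token t.
    """
    K = len(q[0])
    if scale is None:
        scale = 1.0 / math.sqrt(K)
    h = [row[:] for row in h0]  # copy: the caller's h0 is left untouched
    outputs: Mat = []
    intermediates: List[Mat] = []
    for t in range(len(q)):
        qt, kt = q[t], k[t]
        if use_qk_l2norm:
            qt, kt = l2_normalize(qt), l2_normalize(kt)
        # ASSUMED gating (FLA-style): g = -exp(A_log) * softplus(a + dt_bias), beta = sigmoid(b)
        decay = math.exp(-math.exp(A_log) * softplus(a[t] + dt_bias))
        beta = sigmoid(b[t])
        for i, row in enumerate(h):
            # Decay the row, then move its prediction for key kt toward v[t][i] (delta rule).
            for j in range(K):
                row[j] *= decay
            pred = sum(row[j] * kt[j] for j in range(K))
            delta = beta * (v[t][i] - pred)
            for j in range(K):
                row[j] += delta * kt[j]
        # Read out with the scaled query against the updated state.
        outputs.append([scale * sum(row[j] * qt[j] for j in range(K)) for row in h])
        intermediates.append([row[:] for row in h])
    return outputs, intermediates
```

Consider the case K = V = 1 with a zero initial state, A_log = a = dt_bias = b = 0, and two identical tokens (q = k = 1, v = 2). The outputs are about 1.0 and 1.25, because the second token sees the first token's state decayed by exp(-softplus(0)) = 0.5 before the delta correction.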

Parameters:
  • q (torch.Tensor) – Query tensor of shape [B, T, H, K].

  • k (torch.Tensor) – Key tensor of shape [B, T, H, K].

  • v (torch.Tensor) – Value tensor of shape [B, T, HV, V].

  • initial_state (torch.Tensor) – Initial state pool of shape [pool_size, HV, V, K] (K-last layout). Must be float32, because this standalone MTP entry point does not support the BF16 fast path. For a BF16 state pool with K = V = 128, call gated_delta_rule_decode_pretranspose() instead, which dispatches to the BF16 MTP kernel when T > 1. If the pool is contiguous, the kernel reads and writes it in place through a free 4D→3D reshape view. A non-contiguous pool is dispatched through the native 4D use_pool_indexing=True path, and the kernel writes the strided pool in place without densifying it.

  • initial_state_indices (torch.Tensor) – Read indices of shape [B], mapping each batch to its slot in initial_state. Negative entries are treated as padding. For such a batch, the kernel skips both the read and the writeback, and the output for that batch keeps its caller-allocated value (zero when output is None).

  • A_log (torch.Tensor) – Log decay parameter of shape [HV].

  • a (torch.Tensor) – Input-dependent decay of shape [B, T, HV].

  • dt_bias (torch.Tensor) – Decay bias of shape [HV].

  • b (torch.Tensor) – Update gate input of shape [B, T, HV].

  • scale (float, optional) – Scaling factor for queries. If None, uses 1 / sqrt(K).

  • output (torch.Tensor, optional) – Pre-allocated output tensor of shape [B, T, HV, V].

  • intermediate_states_buffer (torch.Tensor, optional) – Buffer for caching intermediate states, of shape [B, T, HV, V, K]. The first dimension is indexed per batch, not per pool slot, so the buffer must have at least B rows. When provided, it must also be contiguous and float32. When None, intermediate states are not cached. Mutually exclusive with ssm_state_indices.

  • ssm_state_indices (torch.Tensor, optional) – Per-token pool scatter indices of shape [B, T] and dtype torch.int32. When provided, the kernel writes each intermediate hidden state h_{t+1} directly to initial_state[ssm_state_indices[i, t]] instead of accumulating into a dense intermediate_states_buffer. Useful for FLA-style speculative-decoding flows where each draft token needs its own pool slot. Constraints: T >= 2, disable_state_update=False, mutually exclusive with intermediate_states_buffer. Default: None.

  • disable_state_update (bool, optional) –

    If True, the initial state is not updated. When omitted (None), it currently behaves as True and emits a deprecation warning.

    Deprecated: the implicit default of True is deprecated and will change to False in FlashInfer 0.7.0. Pass disable_state_update=True or disable_state_update=False explicitly to silence the warning.

  • use_qk_l2norm (bool) – Whether to apply L2 normalization to q and k. Default: True.

  • output_state_indices (torch.Tensor, optional) – Write indices of shape [B] (int32 or int64) specifying the destination pool slot for each batch’s updated state. Defaults to initial_state_indices (read and write target the same slot). Negative entries skip the writeback for that batch (the read still runs).
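The indexing rules for initial_state_indices, output_state_indices, disable_state_update and ssm_state_indices interact in several ways. The following minimal pure-Python sketch restates the rules listed above as plain functions. plan_pool_access and validate_state_args are hypothetical helpers, not part of FlashInfer. They model only the documented behaviour, not how the kernel actually checks or schedules anything. One case is not modelled: whether a final-state writeback also happens when ssm_state_indices is used, because this page does not say.

```python
from typing import List, Optional, Sequence, Tuple


def plan_pool_access(
    initial_state_indices: Sequence[int],
    output_state_indices: Optional[Sequence[int]] = None,
    disable_state_update: bool = False,
) -> List[Tuple[Optional[int], Optional[int]]]:
    """Return (read_slot, write_slot) per batch row; None marks a skipped access."""
    if output_state_indices is None:
        # Default: write back to the slot that was read.
        output_state_indices = initial_state_indices
    plan: List[Tuple[Optional[int], Optional[int]]] = []
    for read, write in zip(initial_state_indices, output_state_indices):
        if read < 0:
            # Padding row: no read, no writeback; its output keeps its initial value.
            plan.append((None, None))
        elif disable_state_update or write < 0:
            # The state is read and used, but the updated state is not written back.
            plan.append((read, None))
        else:
            plan.append((read, write))
    return plan


def validate_state_args(
    T: int,
    has_intermediate_buffer: bool,
    has_ssm_state_indices: bool,
    disable_state_update: Optional[bool],
) -> List[str]:
    """List violations of the documented ssm_state_indices constraints."""
    # None stands for the current implicit default, which behaves as True.
    effective_disable = True if disable_state_update is None else disable_state_update
    problems: List[str] = []
    if has_ssm_state_indices:
        if has_intermediate_buffer:
            problems.append("ssm_state_indices and intermediate_states_buffer are mutually exclusive")
        if T < 2:
            problems.append("ssm_state_indices requires T >= 2")
        if effective_disable:
            problems.append("ssm_state_indices requires disable_state_update=False")
    return problems
```

For example, plan_pool_access([3, -1, 5], output_state_indices=[7, 8, -1]) yields [(3, 7), (None, None), (5, None)]. The second row is padding, and the third row is read but not written back. Also note that leaving disable_state_update unset counts as True under the current default, which validate_state_args reports as incompatible with ssm_state_indices.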

Returns:

(output, initial_state) where output has shape [B, T, HV, V] and initial_state is the updated state (unchanged when disable_state_update=True).

Return type:

Tuple[torch.Tensor, torch.Tensor]
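Because the returned initial_state is unchanged when disable_state_update=True, a speculative-decoding verifier can first run all T draft tokens and then commit only the accepted prefix from the cached intermediate states. The following pure-Python sketch shows that rollback pattern under stated assumptions. It assumes that intermediate_states_buffer[b, t] holds the state after draft token t (h_{t+1}, matching the indexing described for ssm_state_indices). commit_accepted_states and num_accepted are hypothetical names, and states are treated as opaque values. With real tensors, the assignment would be an in-place copy such as initial_state[slot].copy_(intermediate_states_buffer[b, n - 1]).

```python
from typing import Any, List, Sequence


def commit_accepted_states(
    pool: List[Any],
    slot_indices: Sequence[int],
    intermediate_states: Sequence[Sequence[Any]],
    num_accepted: Sequence[int],
) -> List[Any]:
    """Overwrite each batch row's pool slot with the state after its last accepted token.

    ASSUMPTION: intermediate_states[b][t] holds the state after draft token t
    (h_{t+1}), i.e. what intermediate_states_buffer[b, t] would contain.
    """
    for b, (slot, n) in enumerate(zip(slot_indices, num_accepted)):
        if slot < 0 or n == 0:
            # Padding row, or no token accepted: the pool keeps its pre-call state.
            continue
        pool[slot] = intermediate_states[b][n - 1]
    return pool
```

For example, with pool ["s0", "s1", "s2"], slot indices [2, 0] and two accepted tokens for batch 0 but none for batch 1, only slot 2 changes. It receives the state after batch 0's second draft token.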

Notes

  • Requires SM90 (Hopper) architecture.

  • Supports T > 1 (multiple token processing).

  • State layout is K-last: [pool_size, HV, V, K].

  • Optimized for speculative decoding verification scenarios.
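In the K-last layout, K is the fastest-varying (stride-1) dimension of the state pool. The following small pure-Python sketch computes row-major strides for [pool_size, HV, V, K]. It checks that merging the first two dimensions gives a view addressing the same elements, which is what makes a reshape of a contiguous pool free. The sketch assumes the 3D view is [pool_size * HV, V, K]; this page only says "4D→3D reshape view" and does not name the merged dimensions. The helper names are illustrative. In PyTorch, initial_state.reshape(pool_size * HV, V, K) returns such a view without copying when the tensor is contiguous. A strided pool does not in general factor this way, which is presumably why it takes the native 4D path.

```python
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major (C-contiguous) strides in elements; the last dim has stride 1."""
    strides = []
    acc = 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))


def offset(index: Sequence[int], strides: Sequence[int]) -> int:
    """Flat element offset of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


pool_size, HV, V, K = 4, 2, 3, 5
strides_4d = contiguous_strides((pool_size, HV, V, K))   # (30, 15, 5, 1): K is innermost
strides_3d = contiguous_strides((pool_size * HV, V, K))  # (15, 5, 1)

# Element [slot, hv, v, k] of the 4D pool sits at the same address as
# element [slot * HV + hv, v, k] of the merged 3D view.
slot, hv, v, k = 3, 1, 2, 4
assert offset((slot, hv, v, k), strides_4d) == offset((slot * HV + hv, v, k), strides_3d)
```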