flashinfer.mamba.selective_state_update

flashinfer.mamba.selective_state_update(state: Tensor, x: Tensor, dt: Tensor, A: Tensor, B: Tensor, C: Tensor, D: Tensor, z: Tensor | None = None, dt_bias: Tensor | None = None, dt_softplus: bool = False, state_batch_indices: Tensor | None = None, pad_slot_id: int = -1, state_scale: Tensor | None = None, out: Tensor | None = None, disable_state_update: bool = False, intermediate_states_buffer: Tensor | None = None, intermediate_state_indices: Tensor | None = None, intermediate_state_scales: Tensor | None = None, rand_seed: Tensor | None = None, philox_rounds: int = 10, cache_steps: int = 0, algorithm: str = 'auto', dst_state_batch_indices: Tensor | None = None, cu_seqlens: Tensor | None = None, num_accepted_tokens: Tensor | None = None) → Tensor

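Selective state update operation for Mamba layers in the generation (decode) phase. It advances the recurrent SSM state by one token, or by several tokens in the multi-token and varlen modes, and returns the corresponding output.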
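The math behind the op can be written as a short NumPy reference for the single-token case. This is a minimal sketch of the standard Mamba selective-scan recurrence, not FlashInfer's kernel: the function name `selective_state_update_ref` is hypothetical, the sketch assumes the nheads layout, treats each batch element's state as already gathered (no state_batch_indices, no int16 quantization, no stochastic rounding), and assumes dt_bias is added before softplus, as in the reference Mamba code.

```python
import numpy as np


def softplus(v):
    # Numerically stable softplus: log(1 + exp(v)).
    return np.logaddexp(0.0, v)


def silu(v):
    return v / (1.0 + np.exp(-v))


def selective_state_update_ref(state, x, dt, A, B, C, D, z=None,
                               dt_bias=None, dt_softplus=False):
    """Hypothetical single-token reference; updates `state` in place.

    state: (batch, nheads, dim, dstate)   x, dt, z: (batch, nheads, dim)
    A: (nheads, dim, dstate)              B, C: (batch, ngroups, dstate)
    D, dt_bias: (nheads, dim)
    """
    nheads, ngroups = x.shape[1], B.shape[1]
    # Each group of B/C is shared by nheads // ngroups consecutive heads.
    B = np.repeat(B, nheads // ngroups, axis=1)  # (batch, nheads, dstate)
    C = np.repeat(C, nheads // ngroups, axis=1)
    if dt_bias is not None:
        dt = dt + dt_bias                        # assumption: bias before softplus
    if dt_softplus:
        dt = softplus(dt)
    dA = np.exp(dt[..., None] * A)               # discretized A
    dBx = dt[..., None] * B[:, :, None, :] * x[..., None]
    state[...] = state * dA + dBx                # h_t = dA * h_{t-1} + dB * x_t
    out = np.einsum("bhdn,bhn->bhd", state, C) + D * x
    if z is not None:
        out = out * silu(z)                      # optional output gating
    return out
```

Head `h` reads B and C from group `h // (nheads // ngroups)`, which is exactly the ordering that `np.repeat` along the group axis produces.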

Parameters:
  • state (torch.Tensor) – State tensor with shape (state_cache_size, dim, dstate) or (state_cache_size, nheads, dim, dstate)

  • x (torch.Tensor) – Input tensor with shape (batch, dim) or (batch, nheads, dim) for single-token, (batch, T, nheads, dim) for multi-token, or (total_tokens, nheads, dim) for varlen multi-token (with cu_seqlens)

  • dt (torch.Tensor) – Delta time tensor, same layout as x

  • A (torch.Tensor) – A matrix with shape (dim, dstate) or (nheads, dim, dstate)

  • B (torch.Tensor) – B matrix with shape (batch, dstate) or (batch, ngroups, dstate) for single-token, (batch, T, ngroups, dstate) for multi-token, or (total_tokens, ngroups, dstate) for varlen multi-token

  • C (torch.Tensor) – C matrix, same layout as B

  • D (torch.Tensor) – D vector with shape (dim,) or (nheads, dim)

  • z (Optional[torch.Tensor]) – Optional z tensor, same layout as x

  • dt_bias (Optional[torch.Tensor]) – Optional dt bias with shape (dim,) or (nheads, dim)

  • dt_softplus (bool) – Whether to apply softplus to dt (after dt_bias, if given, has been added)

  • state_batch_indices (Optional[torch.Tensor]) – Indices into the state cache that select the slot each batch element reads its state from. Shape (batch,) or (N, max_seqlen). For speculative decoding with num_accepted_tokens, it must be 2D.

  • dst_state_batch_indices (Optional[torch.Tensor]) – Destination indices for state cache writing. Shape (batch,) or (N, max_seqlen). When provided, state is read from state_batch_indices and written to dst_state_batch_indices (enables separate read/write state slots).

  • pad_slot_id (int) – Sentinel value for padded entries in state_batch_indices

  • state_scale (Optional[torch.Tensor]) – Optional float32 scale tensor with shape (state_cache_size, nheads, dim) for int16 state quantization with block scaling. It is listed in the custom op’s mutates_args: when state is quantized (int16), the kernel writes the new per-block scales into this tensor in place, and the caller dequantizes state against it on read-back (this mirrors the intermediate_state_scales contract below).

  • out (Optional[torch.Tensor]) – Optional output tensor (same shape as x)

  • disable_state_update (bool) – If True, skip updating the state tensor (useful for speculative decoding verification)

  • intermediate_states_buffer (Optional[torch.Tensor]) – Optional buffer for caching intermediate states during speculative decoding, with shape (batch, cache_steps, nheads, dim, dstate). It is also listed in mutates_args: the kernel writes intermediate states into it in place.

  • intermediate_state_indices (Optional[torch.Tensor]) – Optional indices mapping batch elements to intermediate state buffer positions with shape (batch,)

  • intermediate_state_scales (Optional[torch.Tensor]) – Optional per-block float32 scale tensor matching intermediate_states_buffer. When provided alongside an int16 intermediate_states_buffer, the kernel writes the computed scales into this tensor (it is listed in the custom op’s mutates_args), and the caller is responsible for dequantizing the intermediate states with these scales when reading them back. It follows the state_scale layout, applied to the speculative-decoding intermediate buffer.

  • rand_seed (Optional[torch.Tensor]) – Optional single-element int64 CUDA tensor holding the seed for stochastic rounding.

  • philox_rounds (int) – Number of Philox-4x32 PRNG rounds for stochastic rounding (default 10).

  • cache_steps (int) – Number of steps/tokens to cache for speculative decoding. For varlen mode (cu_seqlens provided), this specifies max_seqlen.

  • cu_seqlens (Optional[torch.Tensor]) – Cumulative sequence lengths with shape (N + 1,), integer dtype (the JIT specializes on the actual dtype, so int32 or int64 is fine; int32 is the default when omitted). When provided, inputs are in packed-token varlen format: x / dt are 3-D (total_tokens, nheads, dim), B / C are 3-D (total_tokens, ngroups, dstate), with sequence boundaries given by cu_seqlens.

  • num_accepted_tokens (Optional[torch.Tensor]) – Number of accepted tokens per sequence with shape (N,). Determines which state to read as initial state for each sequence.

  • algorithm (str) – Kernel algorithm to use: one of “auto”, “simple”, “vertical”, or “horizontal”
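The varlen and speculative-decoding index parameters (cu_seqlens, state_batch_indices, num_accepted_tokens, pad_slot_id) can be illustrated with a plain-Python sketch. The helper `plan_varlen` is hypothetical and only computes indices: sequence i owns the packed tokens cu_seqlens[i]:cu_seqlens[i + 1], and each sequence gets a read slot. Two behaviors here are assumptions rather than documented facts: that the initial state is taken from column num_accepted_tokens[i] - 1 of the 2-D state_batch_indices, and that sequences whose slot equals pad_slot_id are skipped.

```python
from typing import List, Optional, Sequence, Tuple, Union

Plan = List[Tuple[int, int, int, Optional[int]]]


def plan_varlen(cu_seqlens: Sequence[int],
                state_batch_indices: Sequence[Union[int, Sequence[int]]],
                num_accepted_tokens: Optional[Sequence[int]] = None,
                pad_slot_id: int = -1) -> Plan:
    """Hypothetical helper: (seq, token_start, token_end, read_slot) per sequence.

    read_slot is None for padded sequences.
    """
    plan: Plan = []
    for i in range(len(cu_seqlens) - 1):
        # Packed tokens of sequence i are x[cu_seqlens[i]:cu_seqlens[i + 1]].
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        row = state_batch_indices[i]
        if isinstance(row, int):
            slot = row  # 1-D indices: one slot per sequence
        else:
            # ASSUMPTION: with num_accepted_tokens, the initial state is the
            # slot recorded for the last accepted token of the previous step.
            col = num_accepted_tokens[i] - 1 if num_accepted_tokens is not None else 0
            slot = row[col]
        # ASSUMPTION: padded sequences are skipped (no state read or write).
        plan.append((i, start, end, None if slot == pad_slot_id else slot))
    return plan


tokens = list("abcdef")            # stands in for total_tokens = 6
cu_seqlens = [0, 3, 5, 6]          # N = 3 sequences of lengths 3, 2, 1
state_batch_indices = [[7, 8, 9], [4, -1, -1], [-1, -1, -1]]
num_accepted_tokens = [2, 1, 1]

for seq, start, end, slot in plan_varlen(cu_seqlens, state_batch_indices,
                                         num_accepted_tokens):
    print(seq, tokens[start:end], slot)
# 0 ['a', 'b', 'c'] 8
# 1 ['d', 'e'] 4
# 2 ['f'] None
```

In this layout x and dt stay 3-D (total_tokens, nheads, dim); only cu_seqlens tells the kernel where one sequence ends and the next begins.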

Returns:

output – Output tensor with the same shape as x

Return type:

torch.Tensor
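For the multi-token layout (x of shape (batch, T, nheads, dim)), the same recurrence is applied token by token. The NumPy sketch below shows how intermediate_states_buffer, cache_steps, and disable_state_update relate. The function name `multi_token_update` is hypothetical. The sketch omits dt_bias, softplus, z, state indexing, and quantization, takes B and C with ngroups == nheads, and assumes buffer slot t holds the state after token t. With disable_state_update=True, the outputs and intermediate states are still produced, but the final state is not written back, which is the speculative-decoding verification pattern.

```python
import numpy as np


def multi_token_update(state, x, dt, A, B, C, D, cache_steps,
                       disable_state_update=False):
    """Hypothetical multi-token reference (ngroups == nheads).

    state: (batch, nheads, dim, dstate)   x, dt: (batch, T, nheads, dim)
    A: (nheads, dim, dstate)              B, C: (batch, T, nheads, dstate)
    D: (nheads, dim)
    Returns (out, intermediate_states); intermediate_states has shape
    (batch, cache_steps, nheads, dim, dstate).
    """
    batch, T = x.shape[:2]
    h = state.copy()
    out = np.empty_like(x)
    inter = np.zeros((batch, cache_steps) + state.shape[1:])
    for t in range(T):
        dA = np.exp(dt[:, t, ..., None] * A)
        dBx = dt[:, t, ..., None] * B[:, t, :, None, :] * x[:, t, ..., None]
        h = h * dA + dBx
        out[:, t] = np.einsum("bhdn,bhn->bhd", h, C[:, t]) + D * x[:, t]
        if t < cache_steps:
            inter[:, t] = h  # assumption: slot t holds the state after token t
    if not disable_state_update:
        state[...] = h       # commit only the final state
    return out, inter
```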