flashinfer.mamba.selective_state_update
- flashinfer.mamba.selective_state_update(state: Tensor, x: Tensor, dt: Tensor, A: Tensor, B: Tensor, C: Tensor, D: Tensor, z: Tensor | None = None, dt_bias: Tensor | None = None, dt_softplus: bool = False, state_batch_indices: Tensor | None = None, pad_slot_id: int = -1, state_scale: Tensor | None = None, out: Tensor | None = None, disable_state_update: bool = False, intermediate_states_buffer: Tensor | None = None, intermediate_state_indices: Tensor | None = None, intermediate_state_scales: Tensor | None = None, rand_seed: Tensor | None = None, philox_rounds: int = 10, cache_steps: int = 0, algorithm: str = 'auto', dst_state_batch_indices: Tensor | None = None, cu_seqlens: Tensor | None = None, num_accepted_tokens: Tensor | None = None) → Tensor
Selective state update operation for Mamba layers, used in the generation (decode) phase.
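The per-token recurrence behind this operation can be sketched in NumPy. This is a minimal sketch, assuming the kernel follows the same math as the `mamba_ssm` reference for selective state update: the state decays by exp(dt * A), accumulates dt * B * x, and is read out through C, with an optional D skip term and SiLU(z) gate. The function name `ssu_reference` is hypothetical; it handles only the multi-head single-token layout and omits state indexing, quantization, and stochastic rounding.

```python
import numpy as np


def softplus(v):
    # Numerically stable log(1 + exp(v)).
    return np.logaddexp(0.0, v)


def silu(v):
    return v / (1.0 + np.exp(-v))


def ssu_reference(state, x, dt, A, B, C, D=None, z=None,
                  dt_bias=None, dt_softplus=False):
    """Hypothetical single-token reference (multi-head layout only).

    state: (batch, nheads, dim, dstate), updated in place
    x, dt, z: (batch, nheads, dim)
    A: (nheads, dim, dstate); D, dt_bias: (nheads, dim)
    B, C: (batch, ngroups, dstate)
    """
    nheads = state.shape[1]
    ngroups = B.shape[1]
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = softplus(dt)
    # Each B/C group is shared by nheads // ngroups consecutive heads.
    B = np.repeat(B, nheads // ngroups, axis=1)  # (batch, nheads, dstate)
    C = np.repeat(C, nheads // ngroups, axis=1)
    dA = np.exp(dt[..., None] * A)                  # decay, (batch, nheads, dim, dstate)
    dBx = (dt * x)[..., None] * B[:, :, None, :]    # input, (batch, nheads, dim, dstate)
    state[...] = state * dA + dBx                   # in-place state update
    out = np.einsum("bhdn,bhn->bhd", state, C)      # read-out through C
    if D is not None:
        out = out + x * D                           # skip connection
    if z is not None:
        out = out * silu(z)                         # output gate
    return out
```

With a 1x1x1x1 state of 1.0, x = 2, dt = 0.5, A = -1, B = 3, C = 4, and D = 1, the new state is exp(-0.5) + 3 and the output is 4 * (exp(-0.5) + 3) + 2.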
- Parameters:
state (torch.Tensor) – State tensor with shape (state_cache_size, dim, dstate) or (state_cache_size, nheads, dim, dstate)
x (torch.Tensor) – Input tensor with shape (batch, dim) or (batch, nheads, dim) for single-token, (batch, T, nheads, dim) for multi-token, or (total_tokens, nheads, dim) for varlen multi-token (with cu_seqlens)
dt (torch.Tensor) – Delta time tensor, same layout as x
A (torch.Tensor) – A matrix with shape (dim, dstate) or (nheads, dim, dstate)
B (torch.Tensor) – B matrix with shape (batch, dstate) or (batch, ngroups, dstate) for single-token, (batch, T, ngroups, dstate) for multi-token, or (total_tokens, ngroups, dstate) for varlen multi-token
C (torch.Tensor) – C matrix, same layout as B
D (torch.Tensor) – D vector with shape (dim,) or (nheads, dim)
z (Optional[torch.Tensor]) – Optional z tensor, same layout as x
dt_bias (Optional[torch.Tensor]) – Optional dt bias with shape (dim,) or (nheads, dim)
dt_softplus (bool) – Whether to apply softplus to dt
state_batch_indices (Optional[torch.Tensor]) – Batch indices for state cache reading. Shape (batch,) or (N, max_seqlen). For speculative decoding with num_accepted_tokens, must be 2D.
dst_state_batch_indices (Optional[torch.Tensor]) – Destination indices for state cache writing. Shape (batch,) or (N, max_seqlen). When provided, state is read from state_batch_indices and written to dst_state_batch_indices (enables separate read/write state slots).
pad_slot_id (int) – Sentinel value for padded entries in state_batch_indices
state_scale (Optional[torch.Tensor]) – Optional float32 scale tensor with shape (state_cache_size, nheads, dim) for int16 state quantization with block scaling. Listed in the custom op’s mutates_args: when state is quantized (int16), the kernel writes the new per-block scales here in place, and the caller dequantizes state against this tensor on read-back (mirrors the intermediate_state_scales contract below).
out (Optional[torch.Tensor]) – Optional output tensor (same shape as x)
disable_state_update (bool) – If True, skip updating the state tensor (useful for speculative decoding verification)
intermediate_states_buffer (Optional[torch.Tensor]) – Optional buffer for caching intermediate states during speculative decoding with shape (batch, cache_steps, nheads, dim, dstate). Also listed in mutates_args: the kernel writes intermediate states into it in place.
intermediate_state_indices (Optional[torch.Tensor]) – Optional indices mapping batch elements to intermediate state buffer positions with shape (batch,)
intermediate_state_scales (Optional[torch.Tensor]) – Optional per-block float32 scale tensor matching intermediate_states_buffer. When provided alongside an int16 intermediate_states_buffer, the kernel writes the computed scales into this tensor (it is listed in the custom op’s mutates_args), and the caller is responsible for dequantizing the intermediate states using these scales when reading them back. Mirrors the state_scale layout but for the speculative-decoding intermediate buffer.
rand_seed (Optional[torch.Tensor]) – Optional single-element int64 CUDA tensor holding the seed for stochastic rounding.
philox_rounds (int) – Number of Philox-4x32 PRNG rounds for stochastic rounding (default 10).
cache_steps (int) – Number of steps/tokens to cache for speculative decoding. For varlen mode (cu_seqlens provided), this specifies max_seqlen.
cu_seqlens (Optional[torch.Tensor]) – Cumulative sequence lengths with shape (N + 1,), integer dtype (the JIT specializes on the actual dtype, so int32 or int64 is fine; int32 is the default when omitted). When provided, inputs are in packed-token varlen format: x/dt are 3-D (total_tokens, nheads, dim), B/C are 3-D (total_tokens, ngroups, dstate), with sequence boundaries given by cu_seqlens.
num_accepted_tokens (Optional[torch.Tensor]) – Number of accepted tokens per sequence with shape (N,). Determines which state to read as initial state for each sequence.
algorithm (str) – Algorithm to use: “auto”, “simple”, “vertical”, “horizontal”
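In varlen mode, sequence i occupies packed-token rows cu_seqlens[i] through cu_seqlens[i + 1] - 1, and cache_steps carries the longest sequence length (max_seqlen). The pure-Python sketch below shows that bookkeeping on plain lists. The helper names are hypothetical, and the assumption that a batch entry whose state_batch_indices value equals pad_slot_id is skipped entirely is modeled on the usual Mamba kernel convention rather than confirmed by this page.

```python
from typing import List, Sequence, Tuple


def varlen_segments(cu_seqlens: Sequence[int]) -> List[Tuple[int, int]]:
    """Half-open (start, end) packed-token ranges, one per sequence."""
    return [(cu_seqlens[i], cu_seqlens[i + 1]) for i in range(len(cu_seqlens) - 1)]


def max_seqlen(cu_seqlens: Sequence[int]) -> int:
    """Longest sequence; the value that cache_steps carries in varlen mode."""
    return max(end - start for start, end in varlen_segments(cu_seqlens))


def split_packed(tokens: Sequence, cu_seqlens: Sequence[int]) -> List[list]:
    """Split a packed (total_tokens, ...) sequence into per-sequence lists."""
    return [list(tokens[start:end]) for start, end in varlen_segments(cu_seqlens)]


def active_sequences(state_batch_indices: Sequence[int], pad_slot_id: int = -1) -> List[int]:
    """Batch positions that would be processed (assumes padded slots are skipped)."""
    return [i for i, slot in enumerate(state_batch_indices) if slot != pad_slot_id]
```

For cu_seqlens = [0, 3, 5, 9], the three sequences cover packed rows 0-2, 3-4, and 5-8, so cache_steps would be 4.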
- Returns:
output – Output tensor with same shape as x
- Return type:
torch.Tensor
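For the multi-token layout used in speculative decoding, the intended semantics can be sketched in NumPy: tokens are processed in order, the state after each token is written to intermediate_states_buffer, and disable_state_update leaves the cached state untouched. This is a simplified sketch under the assumption that buffer slot t holds the state after token t; D, z, dt_bias, state indexing, and quantization are omitted, and `multi_token_ref` is a hypothetical name, not part of FlashInfer.

```python
import numpy as np


def multi_token_ref(state, x, dt, A, B, C, intermediate_states_buffer=None,
                    disable_state_update=False):
    """Hypothetical multi-token reference.

    state: (batch, nheads, dim, dstate)
    x, dt: (batch, T, nheads, dim); B, C: (batch, T, ngroups, dstate)
    A: (nheads, dim, dstate)
    intermediate_states_buffer: (batch, cache_steps, nheads, dim, dstate), cache_steps >= T
    """
    nheads = state.shape[1]
    rep = nheads // B.shape[2]   # heads per B/C group
    h = state.copy()             # work on a copy so `state` can stay untouched
    outs = []
    for t in range(x.shape[1]):
        Bt = np.repeat(B[:, t], rep, axis=1)   # (batch, nheads, dstate)
        Ct = np.repeat(C[:, t], rep, axis=1)
        dtt = dt[:, t]                         # (batch, nheads, dim)
        h = h * np.exp(dtt[..., None] * A) + (dtt * x[:, t])[..., None] * Bt[:, :, None, :]
        if intermediate_states_buffer is not None:
            intermediate_states_buffer[:, t] = h   # state after token t
        outs.append(np.einsum("bhdn,bhn->bhd", h, Ct))
    if not disable_state_update:
        state[...] = h           # commit the final state
    return np.stack(outs, axis=1)  # (batch, T, nheads, dim)
```

Keeping every per-token state is what allows a verifier to roll back to the last accepted token without recomputing the prefix.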