flashinfer.mamba.checkpointing_ssu
- flashinfer.mamba.checkpointing_ssu(state: Tensor, old_x: Tensor, old_B: Tensor, old_dt: Tensor, old_cumAdt: Tensor, cache_buf_idx: Tensor, prev_num_accepted_tokens: Tensor, x: Tensor, dt: Tensor, A: Tensor, B: Tensor, C: Tensor, out: Tensor, D: Tensor | None = None, z: Tensor | None = None, dt_bias: Tensor | None = None, dt_softplus: bool = False, state_batch_indices: Tensor | None = None, pad_slot_id: int = -1, state_scale: Tensor | None = None, rand_seed: Tensor | None = None, philox_rounds: int = 10, d_split: int | None = None, cu_seqlens: Tensor | None = None, max_seqlen: int | None = None, enable_pdl: bool = False) → Tensor
Checkpointing SSU with MTP replay using matmul-based parallel token processing.
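For intuition, the recurrence behind an SSU step can be written out for a single (head, dim) row of the state. The plain-Python sketch below is a reference, not the kernel: it assumes a Mamba-2 style scalar decay per head (suggested by the (nheads, T) shape of old_cumAdt), builds the processed dt from dt_bias and the optional softplus, and gates the output with silu(z) as the usual Mamba reference does. ssu_parallel recomputes the same outputs from the cumulative A·dt values through a causal decay matrix L[t][s] = exp(cumAdt[t] - cumAdt[s]), which is the kind of formulation that lets all T tokens be processed together with matmuls. All function names are hypothetical.

```python
import math


def softplus(v: float) -> float:
    # Numerically stable log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def process_dt(dt_raw: float, dt_bias: float = 0.0, dt_softplus: bool = False) -> float:
    # "Processed dt": add the bias, then optionally apply softplus.
    v = dt_raw + dt_bias
    return softplus(v) if dt_softplus else v


def ssu_sequential(h0, A, dts, Bs, Cs, xs, D=0.0, zs=None):
    """Token-by-token update of one (head, dim) state row of length dstate.

    A and each dts[t] are scalars (assumed shared across dim and dstate);
    Bs[t] and Cs[t] are length-dstate lists; xs[t] is this row's input.
    """
    h, ys = list(h0), []
    for t, (dt, B, C, x) in enumerate(zip(dts, Bs, Cs, xs)):
        decay = math.exp(A * dt)
        h = [decay * hn + dt * b * x for hn, b in zip(h, B)]
        y = sum(c * hn for c, hn in zip(C, h)) + D * x  # readout + skip
        if zs is not None:
            y *= silu(zs[t])  # gate
        ys.append(y)
    return h, ys


def ssu_parallel(h0, A, dts, Bs, Cs, xs, D=0.0, zs=None):
    """Same result from cumulative A*dt via a causal decay matrix."""
    cum, acc = [], 0.0
    for dt in dts:
        acc += A * dt
        cum.append(acc)  # cum[t] = sum of A*dt[s] for s <= t
    T = len(dts)
    # L[t][s] = exp(cum[t] - cum[s]) for s <= t, 0 above the diagonal.
    L = [[math.exp(cum[t] - cum[s]) if s <= t else 0.0 for s in range(T)]
         for t in range(T)]
    ys = []
    for t in range(T):
        # Checkpointed state decayed to token t, read out with C[t].
        y = math.exp(cum[t]) * sum(c * hn for c, hn in zip(Cs[t], h0))
        for s in range(t + 1):
            cb = sum(c * b for c, b in zip(Cs[t], Bs[s]))
            y += L[t][s] * cb * dts[s] * xs[s]
        y += D * xs[t]
        if zs is not None:
            y *= silu(zs[t])
        ys.append(y)
    # Final state: the same closed form evaluated at the last token.
    last = cum[-1] if cum else 0.0
    h = [math.exp(last) * hn for hn in h0]
    for s in range(T):
        w = math.exp(last - cum[s]) * dts[s] * xs[s]
        h = [hn + w * b for hn, b in zip(h, Bs[s])]
    return h, ys


if __name__ == "__main__":
    dts = [process_dt(r, dt_bias=0.1, dt_softplus=True) for r in (0.2, -1.0, 0.7)]
    args = dict(h0=[0.3, -0.2], A=-0.5, dts=dts,
                Bs=[[1.0, 0.5], [0.2, -0.4], [0.9, 0.1]],
                Cs=[[0.5, 1.0], [-0.3, 0.8], [0.6, 0.2]],
                xs=[1.0, -2.0, 0.5], D=0.25, zs=[0.4, -0.1, 1.2])
    _, y_seq = ssu_sequential(**args)
    _, y_par = ssu_parallel(**args)
    print(max(abs(a - b) for a, b in zip(y_seq, y_par)))
```

Under the same assumptions, one plausible reading of the replay is that the prev_num_accepted_tokens cached tokens are run through this recurrence from the checkpointed state before the T new tokens, which is why x, B, dt and cumulative A·dt are cached rather than the intermediate states.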
- Parameters:
state (torch.Tensor) – SSM state, shape (state_cache_size, nheads, dim, dstate). Updated in-place.
old_x (torch.Tensor) – Cached x from previous step, shape (state_cache_size, T, nheads, dim). Single-buffered.
old_B (torch.Tensor) – Cached B, shape (state_cache_size, 2, T, ngroups, dstate). Double-buffered.
old_dt (torch.Tensor) – Cached processed dt, shape (state_cache_size, 2, nheads, T). Double-buffered, f32.
old_cumAdt (torch.Tensor) – Cached cumulative A*dt, shape (state_cache_size, 2, nheads, T). Double-buffered, f32.
cache_buf_idx (torch.Tensor) – Which buffer to read (0 or 1), shape (state_cache_size,), int32.
prev_num_accepted_tokens (torch.Tensor) – Number of old tokens to replay, shape (state_cache_size,), int32.
x (torch.Tensor) – New token inputs, shape (batch, T, nheads, dim).
dt (torch.Tensor) – Delta time, shape (batch, T, nheads, dim) with tie_hdim (stride[-1]=0). Accepted in native dtype (e.g. bf16) — converted to f32 internally.
A (torch.Tensor) – Decay rate, shape (nheads, dim, dstate) with tie_hdim.
B (torch.Tensor) – Input projection, shape (batch, T, ngroups, dstate).
C (torch.Tensor) – Output projection, shape (batch, T, ngroups, dstate).
out (torch.Tensor) – Preallocated output, shape (batch, T, nheads, dim).
D (Optional[torch.Tensor]) – Skip connection, shape (nheads, dim).
z (Optional[torch.Tensor]) – Gate, shape (batch, T, nheads, dim).
dt_bias (Optional[torch.Tensor]) – Bias added to dt, shape (nheads, dim) with tie_hdim.
dt_softplus (bool) – Whether to apply softplus to dt.
state_batch_indices (Optional[torch.Tensor]) – Maps batch index to cache slot, shape (batch,), int32 | int64.
pad_slot_id (int) – Sentinel value for padded entries.
state_scale (Optional[torch.Tensor]) – Block-scale decode factors for quantized state, shape (state_cache_size, nheads, dim), f32.
rand_seed (Optional[torch.Tensor]) – Single-element int64 CUDA tensor for stochastic rounding seed.
philox_rounds (int) – Philox PRNG rounds for stochastic rounding (default 10).
d_split (Optional[int]) – Per-head DIM split factor. Exposed only for benchmarking; leave it unset, because setting it slows the kernel down.
cu_seqlens (Optional[torch.Tensor]) – Cumulative sequence lengths with shape (N + 1,), dtype torch.int32, on the same CUDA device as x (the kernel asserts both). When provided, the new-token inputs (x, dt, B, C, out, optionally z) are interpreted in varlen layout, where tokens are packed along the time axis with batch fixed to 1; that is, x is 4-D with shape (1, total_tokens, nheads, dim) instead of the default (batch, T, ...) layout.
max_seqlen (Optional[int]) – Maximum sequence length present in cu_seqlens, used by the kernel to size its per-sequence work tiles. Only meaningful in varlen mode (cu_seqlens is not None); falls back to max_window when omitted (a wider shared-memory allocation than strictly needed, but always safe). Must be None in non-varlen mode (the JIT key is taken from x.size(1)).
enable_pdl (bool) – When True, the kernel is launched with cudaLaunchAttributeProgrammaticStreamSerialization, enabling the in-kernel griddepcontrol.{wait,launch_dependents} PTX to gate on the upstream kernel (e.g. conv1d) and signal the downstream kernel. Caller’s responsibility: the upstream and downstream kernels must also be PDL-paired for the wait/signal to have any effect. Defaults to False.
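In varlen mode the caller packs all sequences along the time axis and describes the boundaries with cu_seqlens. The caller-side sketch below, in plain Python with hypothetical helper names, builds cu_seqlens from per-sequence lengths, derives the value to pass as max_seqlen, and maps a packed token index back to its sequence. live_slots is an equally hypothetical helper listing which entries of state_batch_indices point at real cache slots rather than pad_slot_id. Handing the offsets to the kernel would then look something like torch.tensor(cu, dtype=torch.int32, device=x.device).

```python
from bisect import bisect_right
from itertools import accumulate


def make_cu_seqlens(seqlens):
    """N sequence lengths -> N + 1 cumulative offsets, starting at 0."""
    return [0] + list(accumulate(seqlens))


def max_seqlen_of(cu_seqlens):
    """Longest sequence described by cu_seqlens (the value to pass as max_seqlen)."""
    return max((end - start for start, end in zip(cu_seqlens, cu_seqlens[1:])), default=0)


def sequence_of_token(cu_seqlens, token):
    """Map a packed token index to (sequence index, position in that sequence)."""
    if not 0 <= token < cu_seqlens[-1]:
        raise IndexError(f"token {token} outside [0, {cu_seqlens[-1]})")
    # Last offset <= token; bisect_right also skips zero-length sequences.
    seq = bisect_right(cu_seqlens, token) - 1
    return seq, token - cu_seqlens[seq]


def live_slots(state_batch_indices, pad_slot_id=-1):
    """(batch index, cache slot) pairs whose slot is not the padding sentinel."""
    return [(b, slot) for b, slot in enumerate(state_batch_indices) if slot != pad_slot_id]


if __name__ == "__main__":
    cu = make_cu_seqlens([3, 1, 4])  # three sequences packed into 8 tokens
    total_tokens = cu[-1]            # x would be laid out as (1, total_tokens, nheads, dim)
    print(cu, total_tokens, max_seqlen_of(cu), sequence_of_token(cu, 5))
```

Using bisect keeps the token lookup logarithmic in N and handles empty sequences without special cases.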
- Returns:
out – Output tensor, shape (batch, T, nheads, dim).
- Return type:
torch.Tensor
Notes
In-place updates. The custom op declares mutates_args = ("state", "out", "old_x", "old_B", "old_dt", "old_cumAdt", "state_scale"). The kernel writes the current step’s x / B / dt / cumulative A·dt back into the old_* caches so that the next call can replay them; old_B, old_dt and old_cumAdt are double-buffered, with the slot selected by cache_buf_idx, while old_x is single-buffered. state_scale is also written when state is quantized (int8 / fp8_e4m3fn): the kernel computes new per-block decode scales and stores them here for the caller to dequantize against on read-back.
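For quantized state, the kernel writes per-block decode scales into state_scale. Its shape (state_cache_size, nheads, dim) suggests one scale per dstate-long row, so the plain-Python sketch below assumes symmetric absmax int8 quantization over each such row, with value ≈ code × scale. It also illustrates what stochastic rounding (rand_seed, philox_rounds) provides: a seeded generator makes rounding unbiased in expectation yet reproducible. random.Random is a stand-in for the kernel's Philox generator, quantize_row and dequantize_row are hypothetical names, and the kernel's actual scale formula is not documented in this section.

```python
import math
import random

INT8_MAX = 127  # fp8_e4m3fn would use its max finite value, 448, instead


def quantize_row(row, rng=None):
    """Quantize one dstate-long state row to int8 codes plus one decode scale.

    Decoding is codes[i] * scale. With rng=None values are rounded to nearest
    (Python's round, ties to even); otherwise they are stochastically rounded,
    with random.Random standing in for the kernel's Philox generator.
    """
    amax = max((abs(v) for v in row), default=0.0)
    scale = amax / INT8_MAX if amax > 0 else 1.0  # symmetric absmax scale
    codes = []
    for v in row:
        q = v / scale
        if rng is None:
            c = round(q)
        else:
            lo = math.floor(q)
            # Round up with probability equal to the fractional part, so that,
            # before clamping, the expected code equals q.
            c = lo + (1 if rng.random() < q - lo else 0)
        codes.append(max(-INT8_MAX, min(INT8_MAX, c)))
    return codes, scale


def dequantize_row(codes, scale):
    return [c * scale for c in codes]


if __name__ == "__main__":
    row = [0.5, -1.27, 0.0, 0.3]
    codes, scale = quantize_row(row)
    sr_codes, _ = quantize_row(row, rng=random.Random(1234))  # seeded, reproducible
    print(codes, scale, sr_codes, dequantize_row(codes, scale))
```

Round-to-nearest bounds the per-element error by half a scale step, while stochastic rounding allows up to one step per element but avoids the systematic bias that accumulates when the state is requantized on every call.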