flashinfer.mamba.checkpointing_ssu

flashinfer.mamba.checkpointing_ssu(state: Tensor, old_x: Tensor, old_B: Tensor, old_dt: Tensor, old_cumAdt: Tensor, cache_buf_idx: Tensor, prev_num_accepted_tokens: Tensor, x: Tensor, dt: Tensor, A: Tensor, B: Tensor, C: Tensor, out: Tensor, D: Tensor | None = None, z: Tensor | None = None, dt_bias: Tensor | None = None, dt_softplus: bool = False, state_batch_indices: Tensor | None = None, pad_slot_id: int = -1, state_scale: Tensor | None = None, rand_seed: Tensor | None = None, philox_rounds: int = 10, d_split: int | None = None, cu_seqlens: Tensor | None = None, max_seqlen: int | None = None, enable_pdl: bool = False) → Tensor

Checkpointing selective state update (SSU) with multi-token prediction (MTP) replay, using matmul-based parallel processing of the tokens.
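"Matmul-based parallel token processing" means the per-token recurrence can be evaluated in closed form from cumulative sums of A·dt, the quantity cached in old_cumAdt. The sketch below is a minimal pure-Python reference for one head and one dim channel, assuming Mamba2-style semantics (dt ← softplus(dt + dt_bias), h_t = exp(A·dt_t)·h_{t-1} + dt_t·x_t·B_t, y_t = C_t·h_t + D·x_t, gating by z omitted). It is not FlashInfer's implementation, and all helper names are hypothetical. It checks that the sequential scan and the closed form agree.

```python
import math

def softplus(v):
    # Numerically stable softplus: log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

def process_dt(dt_raw, dt_bias=0.0, dt_softplus=True):
    # Hypothetical helper: add the bias, then optionally apply softplus.
    out = []
    for d in dt_raw:
        d = d + dt_bias
        out.append(softplus(d) if dt_softplus else d)
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def sequential_ssu(h0, x, dt, A, B, C, D=0.0):
    # Token-by-token recurrence over a dstate-sized state vector h.
    h = list(h0)
    ys = []
    for t in range(len(x)):
        decay = math.exp(A * dt[t])
        h = [decay * hi + dt[t] * x[t] * bi for hi, bi in zip(h, B[t])]
        ys.append(dot(C[t], h) + D * x[t])
    return ys, h

def parallel_ssu(h0, x, dt, A, B, C, D=0.0):
    # Closed form from cumulative A*dt (cf. old_cumAdt):
    # y_t = exp(cum_t) C_t.h0 + sum_{s<=t} exp(cum_t - cum_s) dt_s x_s (C_t.B_s)
    cum, acc = [], 0.0
    for d in dt:
        acc += A * d
        cum.append(acc)
    ys = []
    for t in range(len(x)):
        y = math.exp(cum[t]) * dot(C[t], h0)
        for s in range(t + 1):
            # Entry (t, s) of a lower-triangular, decay-weighted C.B^T matrix.
            y += math.exp(cum[t] - cum[s]) * dt[s] * x[s] * dot(C[t], B[s])
        ys.append(y + D * x[t])
    return ys

h0 = [0.5, -0.25]
x = [1.0, -2.0, 0.5]
dt = process_dt([0.1, -0.3, 0.7], dt_bias=0.2)
A = -1.5
B = [[1.0, 0.0], [0.3, 0.4], [-0.2, 1.0]]
C = [[0.5, 0.5], [1.0, -1.0], [0.0, 2.0]]
seq, h_final = sequential_ssu(h0, x, dt, A, B, C, D=0.1)
par = parallel_ssu(h0, x, dt, A, B, C, D=0.1)
```

The closed form replaces a length-T sequential dependency with a T × T lower-triangular product, which is what lets all tokens be processed at once with matrix multiplies.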

Parameters:
  • state (torch.Tensor) – SSM state, shape (state_cache_size, nheads, dim, dstate). Updated in-place.

  • old_x (torch.Tensor) – Cached x from previous step, shape (state_cache_size, T, nheads, dim). Single-buffered.

  • old_B (torch.Tensor) – Cached B, shape (state_cache_size, 2, T, ngroups, dstate). Double-buffered.

  • old_dt (torch.Tensor) – Cached processed dt, shape (state_cache_size, 2, nheads, T). Double-buffered, f32.

  • old_cumAdt (torch.Tensor) – Cached cumulative A*dt, shape (state_cache_size, 2, nheads, T). Double-buffered, f32.

  • cache_buf_idx (torch.Tensor) – Which buffer to read (0 or 1), shape (state_cache_size,), int32.

  • prev_num_accepted_tokens (torch.Tensor) – Number of old tokens to replay, shape (state_cache_size,), int32.

  • x (torch.Tensor) – New token inputs, shape (batch, T, nheads, dim).

  • dt (torch.Tensor) – Delta time, shape (batch, T, nheads, dim) with tie_hdim, i.e. broadcast along dim (stride[-1] == 0). Accepted in its native dtype (e.g. bf16) and converted to f32 internally.

  • A (torch.Tensor) – Decay rate, shape (nheads, dim, dstate) with tie_hdim.

  • B (torch.Tensor) – Input projection, shape (batch, T, ngroups, dstate).

  • C (torch.Tensor) – Output projection, shape (batch, T, ngroups, dstate).

  • out (torch.Tensor) – Preallocated output, shape (batch, T, nheads, dim).

  • D (Optional[torch.Tensor]) – Skip connection, shape (nheads, dim).

  • z (Optional[torch.Tensor]) – Gate, shape (batch, T, nheads, dim).

  • dt_bias (Optional[torch.Tensor]) – Bias added to dt, shape (nheads, dim) with tie_hdim.

  • dt_softplus (bool) – Whether to apply softplus to dt.

  • state_batch_indices (Optional[torch.Tensor]) – Maps batch index to cache slot, shape (batch,), int32 | int64.

  • pad_slot_id (int) – Sentinel value in state_batch_indices that marks padded entries.

  • state_scale (Optional[torch.Tensor]) – Block-scale decode factors for quantized state, shape (state_cache_size, nheads, dim), f32.

  • rand_seed (Optional[torch.Tensor]) – Single-element int64 CUDA tensor for stochastic rounding seed.

  • philox_rounds (int) – Philox PRNG rounds for stochastic rounding (default 10).

  • d_split (Optional[int]) – Per-head DIM split factor. Exposed only for benchmarking; leave it as None, since setting it slows the kernel down.

  • cu_seqlens (Optional[torch.Tensor]) – Cumulative sequence lengths with shape (N + 1,), where N is the number of sequences, dtype torch.int32, on the same CUDA device as x (the kernel asserts both). When provided, the new-token inputs (x, dt, B, C, out, and optionally z) are interpreted in varlen layout, where tokens are packed along the time axis with batch fixed to 1 (i.e. x is 4-D with shape (1, total_tokens, nheads, dim)) instead of the default (batch, T, ...) layout.

  • max_seqlen (Optional[int]) – Maximum sequence length present in cu_seqlens, used by the kernel to size its per-sequence work tiles. Only meaningful in varlen mode (cu_seqlens is not None); falls back to max_window when omitted, which allocates more shared memory than strictly needed but is always safe. Must be None in non-varlen mode (the JIT key is taken from x.size(1)).

  • enable_pdl (bool) – When True the kernel is launched with cudaLaunchAttributeProgrammaticStreamSerialization, enabling the in-kernel griddepcontrol.{wait,launch_dependents} PTX to gate on the upstream (e.g. conv1d) and signal the downstream kernel. Caller’s responsibility: upstream/downstream kernels must also be PDL-paired for the wait/signal to have effect. Defaults to False.
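A minimal sketch of how the varlen arguments relate, assuming N sequences of known lengths packed back to back along the time axis. The helper names are hypothetical; before calling checkpointing_ssu, the list would still need to become an int32 tensor on x's device (e.g. torch.tensor(cu_seqlens, dtype=torch.int32, device=x.device)).

```python
from itertools import accumulate

def make_cu_seqlens(seq_lens):
    # Hypothetical helper: prefix sum with a leading 0, so the length is N + 1.
    return [0] + list(accumulate(seq_lens))

def sequence_slices(cu_seqlens):
    # Token range [start, end) of each sequence along the packed time axis.
    return [(cu_seqlens[i], cu_seqlens[i + 1]) for i in range(len(cu_seqlens) - 1)]

seq_lens = [3, 1, 4]                     # N = 3 sequences
cu_seqlens = make_cu_seqlens(seq_lens)
total_tokens = cu_seqlens[-1]            # x would have shape (1, total_tokens, nheads, dim)
max_seqlen = max(seq_lens)               # value to pass as max_seqlen
```

Passing the true maximum rather than relying on the max_window fallback lets the kernel size its tiles to the longest sequence actually present.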

Returns:

out – Output tensor, shape (batch, T, nheads, dim), or (1, total_tokens, nheads, dim) in varlen mode.

Return type:

torch.Tensor

Notes

In-place updates. The custom op declares mutates_args = ("state", "out", "old_x", "old_B", "old_dt", "old_cumAdt", "state_scale"). The old_* cache tensors are rewritten on every call: old_B, old_dt and old_cumAdt are double-buffered while old_x is single-buffered, and the kernel writes the current step's x / B / dt / cumulative-A·dt back into them (the double-buffered ones into the slot selected by cache_buf_idx) so the next call can replay them. state_scale is also written when state is quantized (int8 / fp8_e4m3fn): the kernel computes new per-block decode scales and stores them here for the caller to dequantize against on read-back.
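Judging by its shape (state_cache_size, nheads, dim), state_scale holds one decode factor per (slot, head, dim) row, i.e. one scale per block of dstate values. The sketch below illustrates what such a decode factor means using a generic symmetric int8 block-scale scheme with round-to-nearest. It is an assumption for illustration only: the kernel's actual scale computation, its fp8_e4m3fn path and its stochastic rounding (rand_seed, philox_rounds) are not reproduced, and the helper names are hypothetical.

```python
def quantize_block_int8(values):
    # Symmetric per-block scale: the largest magnitude maps to 127.
    amax = max(abs(v) for v in values)
    scale = amax / 127.0 if amax > 0 else 1.0   # decode factor for this block
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale

def dequantize_block(q, scale):
    # Read-back: multiply the stored integers by the per-block decode factor.
    return [qi * scale for qi in q]

row = [0.8, -1.27, 0.05, 0.0]   # one (slot, head, dim) row over dstate
q, scale = quantize_block_int8(row)
restored = dequantize_block(q, scale)
```

With round-to-nearest, each restored value is within half a scale step of the original; stochastic rounding trades that deterministic bound for an unbiased error across repeated updates.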