flashinfer.cute_dsl

CuTe-DSL implementations of selected FlashInfer kernels. These symbols are available only when the nvidia-cutlass-dsl package is installed and the host has a supported NVIDIA GPU; the module guards its imports with is_cute_dsl_available().

Note

A handful of GEMM symbols (grouped_gemm_nt_masked, Sm100BlockScaledPersistentDenseGemmKernel, create_scale_factor_tensor) used to live in flashinfer.cute_dsl and are still re-exported for backwards compatibility, but their canonical home is flashinfer.gemm. New code should import from flashinfer.gemm.

Availability

is_cute_dsl_available()

Return True when the optional CuTe DSL stack is importable.
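A minimal sketch of guarding an optional CuTe-DSL code path. The helper name cute_dsl_usable and the backend labels are hypothetical. The only assumption about FlashInfer is that is_cute_dsl_available can be imported from flashinfer.cute_dsl, as listed above. The helper returns False when flashinfer itself is not installed.

```python
import importlib.util


def cute_dsl_usable() -> bool:
    """Return True only if flashinfer is installed and reports CuTe-DSL support."""
    # find_spec on a top-level name checks for the package without importing it.
    if importlib.util.find_spec("flashinfer") is None:
        return False
    try:
        from flashinfer.cute_dsl import is_cute_dsl_available
    except ImportError:
        # Assumption: an older or partial install may lack this submodule.
        return False
    return bool(is_cute_dsl_available())


# Hypothetical labels for whichever code paths the caller maintains.
backend = "cute_dsl" if cute_dsl_usable() else "fallback"
```

Checking once at start-up and caching the result keeps the optional dependency out of hot paths.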

RMSNorm + FP4 Quantization

rmsnorm_fp4quant(input, weight[, y_fp4, ...])

Fused RMS normalization with FP4 quantization using CuTe-DSL.

add_rmsnorm_fp4quant(input, residual, weight)

Fused Add + RMS normalization + FP4 quantization using CuTe-DSL.

class flashinfer.cute_dsl.RMSNormFP4QuantKernel(dtype: Numeric, H: int, block_size: int, output_swizzled: bool, is_fp16: bool, sm_version: int | None = None, scale_format: str | None = None)

Fused RMSNorm + FP4 Quantization Kernel.

Key optimizations:
  1. Half2/BFloat2 SIMD for max-abs computation

  2. Branchless scale clamping via fmin_f32

  3. Cluster synchronization for large H dimensions

  4. Direct 128-bit vectorized global loads

__init__(dtype: Numeric, H: int, block_size: int, output_swizzled: bool, is_fp16: bool, sm_version: int | None = None, scale_format: str | None = None)

kernel(mX: Tensor, mW: Tensor, mY: Tensor, mS: Tensor, mGlobalScale: Tensor, M: Int32, eps: Float32, enable_pdl: Constexpr[bool], tv_layout: Layout, tiler_mn: int | Integer | Tuple[Shape, ...])

Device kernel with cluster synchronization for large H.

mGlobalScale contains the global scale value. The kernel reads it, computes 1/global_scale, and folds that reciprocal into rstd, so the normalized output is: y = x * rstd * w / global_scale = rmsnorm(x, w) / global_scale
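The scaling formula above can be sketched on the host for a single row. This is a plain-Python reference, not the kernel's code: the function name is hypothetical, and rstd is assumed to follow the usual RMSNorm definition 1/sqrt(mean(x^2) + eps), which the source does not spell out.

```python
import math


def rmsnorm_scaled_reference(x, w, global_scale, eps=1e-6):
    """Compute y = x * rstd * w / global_scale for one row of length H."""
    # Assumption: rstd is the reciprocal root-mean-square of the row.
    mean_sq = sum(v * v for v in x) / len(x)
    rstd = 1.0 / math.sqrt(mean_sq + eps)
    # Fold 1/global_scale into rstd once, then apply a single factor per element.
    factor = rstd * (1.0 / global_scale)
    return [xi * factor * wi for xi, wi in zip(x, w)]


row = rmsnorm_scaled_reference([3.0, 4.0], [1.0, 1.0], global_scale=2.0)
```

Folding the reciprocal into rstd replaces a per-element division with a multiplication by one precomputed factor.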

class flashinfer.cute_dsl.AddRMSNormFP4QuantKernel(dtype: Numeric, H: int, block_size: int, output_swizzled: bool, is_fp16: bool, sm_version: int | None = None, scale_format: str | None = None, output_both_sf_layouts: bool = False)

Fused Add + RMSNorm + FP4 Quantization Kernel.

Computes:
  1. residual = input + residual (in-place update)

  2. y = RMSNorm(residual) * weight

  3. quantize y to FP4

The residual tensor is modified in-place. Supports both NVFP4 (block_size=16) and MXFP4 (block_size=32) formats.
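The three steps above can be sketched for one row in plain Python. Several parts are assumptions rather than details from this page: the function name is hypothetical, the FP4 value grid is the standard E2M1 magnitude set, and each block's scale is taken as max-abs divided by 6.0 (the largest E2M1 magnitude). The sketch keeps scales as Python floats, omits the global scale, and returns dequantizable grid values instead of packed 4-bit codes, so it is only a rough numerical reference.

```python
import math

# Magnitudes representable in FP4 E2M1; the sign is handled separately.
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def add_rmsnorm_fp4_reference(x, residual, weight, block_size=16, eps=1e-6):
    """Return (grid_values, block_scales); residual is updated in place."""
    # Step 1: residual = input + residual, written back in place.
    for i, xi in enumerate(x):
        residual[i] += xi
    # Step 2: y = RMSNorm(residual) * weight.
    mean_sq = sum(h * h for h in residual) / len(residual)
    rstd = 1.0 / math.sqrt(mean_sq + eps)
    y = [h * rstd * wi for h, wi in zip(residual, weight)]
    # Step 3: per-block scaling, then round to the nearest E2M1 magnitude.
    values, scales = [], []
    for start in range(0, len(y), block_size):
        block = y[start:start + block_size]
        amax = max(abs(v) for v in block)
        scale = amax / 6.0 if amax > 0.0 else 1.0  # assumed scale rule
        scales.append(scale)
        for v in block:
            mag = min(E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
            values.append(math.copysign(mag, v))
    return values, scales
```

Multiplying each returned value by its block's scale gives the dequantized approximation of y. Passing block_size=32 corresponds to the MXFP4 block length named above.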

__init__(dtype: Numeric, H: int, block_size: int, output_swizzled: bool, is_fp16: bool, sm_version: int | None = None, scale_format: str | None = None, output_both_sf_layouts: bool = False)

kernel(mX: Tensor, mR: Tensor, mW: Tensor, mY: Tensor, mS: Tensor, mS_unswizzled: Tensor, mGlobalScale: Tensor, M: Int32, eps: Float32, enable_pdl: Constexpr[bool], tv_layout: Layout, tiler_mn: int | Integer | Tuple[Shape, ...])

Device kernel with cluster sync and Half2 SIMD.

Performs:
  1. h = input + residual (writes h back to mR in-place)

  2. y = h * rstd * w / global_scale = rmsnorm(h, w) / global_scale

  3. quantizes y to FP4

mGlobalScale contains the global scale value. The kernel reads it, computes 1/global_scale, and folds that reciprocal into rstd, which produces the division by global_scale in step 2.

Attention Wrappers

CuTe-DSL implementations of the batch attention wrappers.

class flashinfer.cute_dsl.attention.wrappers.batch_mla.BatchMLADecodeCuteDSLWrapper(workspace_buffer: Tensor)

PyTorch-facing wrapper for the modular MLA decode kernel.

Usage:

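import torch
from flashinfer.cute_dsl.attention.wrappers.batch_mla import BatchMLADecodeCuteDSLWrapper

# workspace_buffer, query, kv_cache, block_tables, seq_lens and max_seq_len
# are caller-provided; see __init__ and run below for their expected shapes.
wrapper = BatchMLADecodeCuteDSLWrapper(workspace_buffer)
# plan compiles the kernel for this configuration or reuses a cached one.
wrapper.plan(
    kv_lora_rank=512, qk_rope_head_dim=64, num_heads=128,
    page_size=64, q_dtype=torch.bfloat16,
)
out = wrapper.run(query, kv_cache, block_tables, seq_lens, max_seq_len,
                  softmax_scale=0.125)

__init__(workspace_buffer: Tensor) → None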

Bind the wrapper to a user-provided workspace buffer.

Parameters:

workspace_buffer (torch.Tensor) – Pre-allocated workspace buffer on the target CUDA device. Must have dtype torch.int8; its size determines the maximum batch size this wrapper can handle without re-allocation.

plan(kv_lora_rank: int = 512, qk_rope_head_dim: int = 64, num_heads: int = 128, page_size: int = 1, q_dtype: dtype = torch.bfloat16, out_dtype: dtype | None = None, is_var_seq: bool = True, enable_pdl: bool | None = None, variant: AttentionVariant | None = None) → None

Compile (or retrieve cached) MLA decode kernel for the given config.

Parameters:
  • kv_lora_rank (int) – Latent dimension (e.g. 512).

  • qk_rope_head_dim (int) – RoPE dimension (e.g. 64).

  • num_heads (int) – Number of attention heads (typically 128 for DeepSeek-V3).

  • page_size (int) – KV cache page size.

  • q_dtype (torch.dtype) – Query/KV data type (float16 or bfloat16).

  • out_dtype (Optional[torch.dtype]) – Output data type. Defaults to same as q_dtype.

  • is_var_seq (bool) – Whether sequence lengths vary across the batch.

  • enable_pdl (Optional[bool]) – Whether to enable Programmatic Dependent Launch. Auto-detects if None.

  • variant (Optional[AttentionVariant]) – Attention variant (ALiBi, SoftCapping, AttentionWithSink, etc.). None uses standard softmax attention.

run(q: Tensor, kv_cache: Tensor, block_tables: Tensor, seq_lens: Tensor, max_seq_len: int, softmax_scale: float, output_scale: float = 1.0, out: Tensor | None = None) → Tensor

Run the MLA decode kernel.

Parameters:
  • q (torch.Tensor) – [B, q_len, H, D_qk] where D_qk = kv_lora_rank + qk_rope_head_dim.

  • kv_cache (torch.Tensor) – [num_pages, page_size, D_total] (3D) or [num_pages, 1, page_size, D_total] (4D).

  • block_tables (torch.Tensor) – [B, max_pages] page table indices.

  • seq_lens (torch.Tensor) – [B] per-request KV sequence lengths.

  • max_seq_len (int) – Maximum sequence length across the batch.

  • softmax_scale (float) – Scale factor for QK^T before softmax.

  • output_scale (float) – Scale factor applied to the output.

  • out (Optional[torch.Tensor]) – Pre-allocated output [B, q_len, H, kv_lora_rank].

Returns:

Output tensor [B, q_len, H, kv_lora_rank].

Return type:

torch.Tensor
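The shape relationships documented for run can be checked on the host before launching. The sketch below is pure Python and does not call FlashInfer. Both helper names are hypothetical, and the page-count rule (ceil of sequence length over page size) is an assumption about how a paged KV cache is laid out, not something this page states.

```python
def mla_decode_shapes(batch, q_len, num_heads, kv_lora_rank=512, qk_rope_head_dim=64):
    """Expected q and output shapes for the MLA decode wrapper's run method."""
    # D_qk is the latent dimension plus the RoPE dimension, per the q parameter docs.
    d_qk = kv_lora_rank + qk_rope_head_dim
    return {
        "q": (batch, q_len, num_heads, d_qk),
        "out": (batch, q_len, num_heads, kv_lora_rank),
    }


def pages_needed(seq_len, page_size):
    """Assumed page count for one request: ceil(seq_len / page_size)."""
    return -(-seq_len // page_size)


shapes = mla_decode_shapes(batch=4, q_len=1, num_heads=128)
```

With the defaults, D_qk is 512 + 64 = 576, while the output keeps only the 512-wide latent dimension.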

class flashinfer.cute_dsl.attention.wrappers.batch_prefill.BatchPrefillCuteDSLWrapper(float_workspace_buffer: Tensor, use_cuda_graph: bool = False)

PyTorch-facing wrapper for the CuTe-DSL ragged-KV batch prefill kernel.

This wrapper exposes a plan + run API compatible with flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper, but compiles a CuTe-DSL kernel under the hood instead of the C++ FA2/FA3 path.

Example

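from flashinfer.cute_dsl.attention.wrappers.batch_prefill import BatchPrefillCuteDSLWrapper

# workspace_buffer, qo_indptr, kv_indptr, q, k and v are caller-provided tensors.
wrapper = BatchPrefillCuteDSLWrapper(workspace_buffer)
wrapper.plan(qo_indptr, kv_indptr,
             num_qo_heads=32, num_kv_heads=8, head_dim_qk=128)
out = wrapper.run(q, k, v)

__init__(float_workspace_buffer: Tensor, use_cuda_graph: bool = False) → None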

Initialise the wrapper and bind it to a workspace buffer.

Parameters:
  • float_workspace_buffer (torch.Tensor) – Pre-allocated workspace buffer on the target CUDA device. Named for API parity with BatchPrefillWithRaggedKVCacheWrapper; callers typically pass a torch.uint8 tensor. The CuTe-DSL kernel itself does not consume this buffer; it is retained so the wrapper mirrors the BatchPrefillWithRaggedKVCacheWrapper constructor.

  • use_cuda_graph (bool) – Whether the wrapper will be used inside a CUDA graph capture. Defaults to False.

plan(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo=None, causal=True, sm_scale=1.0, q_data_type=torch.float16, kv_data_type=torch.float16, window_left: int = -1, variant: AttentionVariant | None = None) → None

Compile the FMHA prefill kernel for the given configuration.

Parameters:
  • qo_indptr (torch.Tensor) – Cumulative query sequence lengths, shape [batch_size + 1].

  • kv_indptr (torch.Tensor) – Cumulative KV sequence lengths, shape [batch_size + 1].

  • num_qo_heads (int) – Number of query/output heads.

  • num_kv_heads (int) – Number of key/value heads (must divide num_qo_heads).

  • head_dim_qk (int) – Head dimension for queries and keys.

  • head_dim_vo (Optional[int]) – Head dimension for values and output. Must equal head_dim_qk if set.

  • causal (bool) – Whether to apply causal masking.

  • sm_scale (float) – Softmax scale factor (typically 1/sqrt(head_dim)).

  • q_data_type (torch.dtype) – Data type for queries (float16, bfloat16, or float8_e4m3fn).

  • kv_data_type (torch.dtype) – Data type for keys/values.

  • window_left (int) – Sliding window size. -1 disables sliding window.

  • variant (Optional[AttentionVariant]) – Attention variant (ALiBi, RPE, Sigmoid, etc.). None uses standard softmax.
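The indptr arguments and the head-count constraint above can be prepared with a few lines of host code. This is a pure-Python sketch with hypothetical helper names. In practice the index lists would be converted to tensors on the target device before being passed to plan, and that conversion is left out here.

```python
import math
from itertools import accumulate


def build_indptr(seq_lens):
    """Cumulative lengths with a leading 0, giving batch_size + 1 entries."""
    return [0, *accumulate(seq_lens)]


def check_plan_args(num_qo_heads, num_kv_heads, head_dim_qk):
    """Validate the head counts and derive the typical softmax scale."""
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_kv_heads must divide num_qo_heads")
    return {
        # Number of query heads that share one key/value head.
        "group_size": num_qo_heads // num_kv_heads,
        # Typical choice noted in the sm_scale description: 1/sqrt(head_dim).
        "sm_scale": 1.0 / math.sqrt(head_dim_qk),
    }


qo_indptr = build_indptr([3, 5, 2])
plan_args = check_plan_args(num_qo_heads=32, num_kv_heads=8, head_dim_qk=128)
```

Note that plan defaults sm_scale to 1.0, so the 1/sqrt(head_dim) value has to be passed explicitly when it is wanted.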

run(q: Tensor, k: Tensor, v: Tensor, out: Tensor | None = None) → Tensor

Run the prefill attention computation.

Parameters:
  • q (torch.Tensor) – The query tensor with shape [total_q_len, num_qo_heads, head_dim_qk].

  • k (torch.Tensor) – The key tensor with shape [total_kv_len, num_kv_heads, head_dim_qk].

  • v (torch.Tensor) – The value tensor with shape [total_kv_len, num_kv_heads, head_dim_vo].

  • out (Optional[torch.Tensor]) – The output tensor. If None, a new tensor will be created.

Returns:

The output tensor with shape [total_q_len, num_qo_heads, head_dim_vo].

Return type:

torch.Tensor
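Because the inputs and the output are packed along the first dimension, per-request results are recovered by slicing with the same indptr passed to plan. The sketch below shows that indexing on a plain Python list. The helper name is hypothetical, and the same slice bounds would apply to the first dimension of the returned tensor, assuming it follows the packed layout described above.

```python
def split_ragged(rows, indptr):
    """Split packed rows into one chunk per request using cumulative offsets."""
    return [rows[indptr[i]:indptr[i + 1]] for i in range(len(indptr) - 1)]


# Three requests with query lengths 3, 5 and 2, packed into 10 rows.
packed_rows = list(range(10))
per_request = split_ragged(packed_rows, [0, 3, 8, 10])
```

Request i owns rows indptr[i] up to but not including indptr[i + 1], so a request of length zero yields an empty chunk.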