flashinfer.cute_dsl
CuTe-DSL implementations of selected FlashInfer kernels. These symbols are
available only when the nvidia-cutlass-dsl package is installed and the
host has a supported NVIDIA GPU; the module guards its imports with
is_cute_dsl_available().
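A minimal sketch of how calling code might probe for this module before relying on it. The helper name cute_dsl_ready is hypothetical, and the sketch assumes is_cute_dsl_available can be reached as a zero-argument callable on flashinfer.cute_dsl; it reports False when the package cannot be imported.

```python
import importlib


def cute_dsl_ready() -> bool:
    """Return True only if flashinfer.cute_dsl imports and reports availability."""
    try:
        module = importlib.import_module("flashinfer.cute_dsl")
    except (ImportError, OSError):
        # flashinfer, or one of its native dependencies, is not usable here.
        return False
    # Assumption: the guard is exposed as a zero-argument callable on the module.
    check = getattr(module, "is_cute_dsl_available", None)
    return bool(check()) if callable(check) else False


if __name__ == "__main__":
    print("CuTe-DSL kernels usable:", cute_dsl_ready())
```

Callers can branch on the result and fall back to a non-CuTe-DSL code path when it is False.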
Note
A handful of GEMM symbols (grouped_gemm_nt_masked,
Sm100BlockScaledPersistentDenseGemmKernel,
create_scale_factor_tensor) used to live in flashinfer.cute_dsl and
are still re-exported for backwards compatibility, but their canonical
home is flashinfer.gemm. New code should import from flashinfer.gemm.
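One way to follow this advice while still tolerating older installs is to try the canonical module first and the legacy re-export second. The resolve_symbol helper below is hypothetical and not part of FlashInfer; it is a sketch that simply walks a list of module names in order.

```python
import importlib


def resolve_symbol(name, module_names):
    """Return the first attribute called `name` found in `module_names`, else None."""
    for module_name in module_names:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue  # module not installed: try the next candidate
        if hasattr(module, name):
            return getattr(module, name)
    return None


# Prefer the canonical home, then the backwards-compatible re-export.
grouped_gemm_nt_masked = resolve_symbol(
    "grouped_gemm_nt_masked",
    ["flashinfer.gemm", "flashinfer.cute_dsl"],
)
```

A result of None means neither module provided the symbol, so the caller should choose a different path.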
Availability
Return |
RMSNorm + FP4 Quantization
- Fused RMS normalization with FP4 quantization using CuTe-DSL.
- Fused Add + RMS normalization + FP4 quantization using CuTe-DSL.
- class flashinfer.cute_dsl.RMSNormFP4QuantKernel(dtype: Numeric, H: int, block_size: int, output_swizzled: bool, is_fp16: bool, sm_version: int | None = None, scale_format: str | None = None)
Fused RMSNorm + FP4 Quantization Kernel.
Key optimizations:
1. Half2/BFloat2 SIMD for max-abs computation
2. Branchless scale clamping via fmin_f32
3. Cluster synchronization for large H dimensions
4. Direct 128-bit vectorized global loads
- __init__(dtype: Numeric, H: int, block_size: int, output_swizzled: bool, is_fp16: bool, sm_version: int | None = None, scale_format: str | None = None)
- kernel(mX: Tensor, mW: Tensor, mY: Tensor, mS: Tensor, mGlobalScale: Tensor, M: Int32, eps: Float32, enable_pdl: Constexpr[bool], tv_layout: Layout, tiler_mn: int | Integer | Tuple[Shape, ...])
Device kernel with cluster synchronization for large H.
mGlobalScale contains the global scale value. The kernel reads it and computes 1/global_scale, which is multiplied with rstd to apply: y = x * rstd * w / global_scale = rmsnorm(x, w) / global_scale
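The formula above can be checked against a small host-side reference. This is a pure-Python sketch for a single row, not the kernel's implementation; the function name is hypothetical, and it assumes the usual RMSNorm definition rstd = 1 / sqrt(mean(x^2) + eps).

```python
import math


def rmsnorm_scaled_reference(x, w, global_scale, eps=1e-6):
    """Reference for one row: y = x * rstd * w / global_scale."""
    if len(x) != len(w):
        raise ValueError("x and w must have the same length H")
    mean_sq = sum(v * v for v in x) / len(x)
    rstd = 1.0 / math.sqrt(mean_sq + eps)
    # Fold 1/global_scale into rstd once, so each element needs two multiplies.
    row_scale = rstd * (1.0 / global_scale)
    return [xi * row_scale * wi for xi, wi in zip(x, w)]


if __name__ == "__main__":
    print(rmsnorm_scaled_reference([3.0, 4.0], [1.0, 1.0], global_scale=2.0))
```

Folding the reciprocal into the per-row factor mirrors the description above: the division by global_scale costs one extra multiply per row rather than one per element.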
- class flashinfer.cute_dsl.AddRMSNormFP4QuantKernel(dtype: Numeric, H: int, block_size: int, output_swizzled: bool, is_fp16: bool, sm_version: int | None = None, scale_format: str | None = None, output_both_sf_layouts: bool = False)
Fused Add + RMSNorm + FP4 Quantization Kernel.
- Computes:
residual = input + residual (in-place update)
y = RMSNorm(residual) * weight
quantize y to FP4
The residual tensor is modified in-place. Supports both NVFP4 (block_size=16) and MXFP4 (block_size=32) formats.
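To illustrate what block-wise FP4 quantization means, here is a simplified host-side sketch. It is not the kernel's algorithm: the function name is hypothetical, the scale is kept as a plain Python float (real NVFP4 and MXFP4 data store the per-block scale in a narrow 8-bit format), ties are not rounded the way hardware would round them, and the global scale is ignored. The only format fact it relies on is the set of magnitudes representable in FP4 E2M1.

```python
import math

# The eight non-negative magnitudes representable in FP4 E2M1.
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def quantize_blocks_e2m1(row, block_size):
    """Return (scales, values): one scale per block, values snapped to the grid."""
    if len(row) % block_size != 0:
        raise ValueError("row length must be a multiple of block_size")
    scales, values = [], []
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        amax = max(abs(v) for v in block)
        # Map the block's largest magnitude onto 6.0, the largest E2M1 value.
        scale = amax / 6.0 if amax > 0.0 else 1.0
        scales.append(scale)
        for v in block:
            nearest = min(E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
            values.append(math.copysign(nearest, v))
    return scales, values


if __name__ == "__main__":
    # block_size=4 keeps the demo short; the formats above use 16 or 32.
    print(quantize_blocks_e2m1([0.0, 1.0, -6.0, 2.4, 12.0, 3.0, 0.0, 0.0], 4))
```

Multiplying each returned value by its block's scale gives the dequantized approximation of the input.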
- __init__(dtype: Numeric, H: int, block_size: int, output_swizzled: bool, is_fp16: bool, sm_version: int | None = None, scale_format: str | None = None, output_both_sf_layouts: bool = False)
- kernel(mX: Tensor, mR: Tensor, mW: Tensor, mY: Tensor, mS: Tensor, mS_unswizzled: Tensor, mGlobalScale: Tensor, M: Int32, eps: Float32, enable_pdl: Constexpr[bool], tv_layout: Layout, tiler_mn: int | Integer | Tuple[Shape, ...])
Device kernel with cluster sync and Half2 SIMD.
Performs:
1. h = input + residual (writes h back to mR in-place)
2. y = h * rstd * w / global_scale = rmsnorm(h, w) / global_scale
3. quantizes y to FP4
mGlobalScale contains the global scale value. The kernel reads it and computes 1/global_scale, which is multiplied with rstd to apply: y = h * rstd * w / global_scale = rmsnorm(h, w) / global_scale
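The first two steps can be mirrored on the host as a correctness reference. The sketch below is pure Python for a single row with a hypothetical function name; it assumes rstd = 1 / sqrt(mean(h^2) + eps) and leaves out the FP4 quantization step.

```python
import math


def add_rmsnorm_reference(x, residual, w, global_scale, eps=1e-6):
    """Update `residual` in place with x + residual, then return the scaled row."""
    for i, xi in enumerate(x):
        residual[i] += xi  # step 1: h = input + residual, written back in place
    mean_sq = sum(h * h for h in residual) / len(residual)
    rstd = 1.0 / math.sqrt(mean_sq + eps)
    inv_global_scale = 1.0 / global_scale
    # step 2: y = h * rstd * w / global_scale (quantizing y to FP4 is omitted)
    return [h * rstd * wi * inv_global_scale for h, wi in zip(residual, w)]


if __name__ == "__main__":
    res = [1.0, 1.0]
    print(add_rmsnorm_reference([2.0, 3.0], res, [1.0, 1.0], global_scale=1.0))
    print(res)  # the residual now holds input + residual
```

As with the kernel, the caller's residual buffer is overwritten, so any code that still needs the old residual must copy it first.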
Attention Wrappers
CuTe-DSL implementations of the batch attention wrappers.
- class flashinfer.cute_dsl.attention.wrappers.batch_mla.BatchMLADecodeCuteDSLWrapper(workspace_buffer: Tensor)
PyTorch-facing wrapper for the modular MLA decode kernel.
Usage:
    wrapper = BatchMLADecodeCuteDSLWrapper(workspace_buffer)
    wrapper.plan(
        kv_lora_rank=512,
        qk_rope_head_dim=64,
        num_heads=128,
        page_size=64,
        q_dtype=torch.bfloat16,
    )
    out = wrapper.run(query, kv_cache, block_tables, seq_lens, max_seq_len, softmax_scale=0.125)
- __init__(workspace_buffer: Tensor) → None
Bind the wrapper to a user-provided workspace buffer.
- Parameters:
workspace_buffer (torch.Tensor) – Pre-allocated workspace buffer on the target CUDA device. Must have dtype torch.int8; the size determines the maximum batch size this wrapper can handle without re-allocation.
- plan(kv_lora_rank: int = 512, qk_rope_head_dim: int = 64, num_heads: int = 128, page_size: int = 1, q_dtype: dtype = torch.bfloat16, out_dtype: dtype | None = None, is_var_seq: bool = True, enable_pdl: bool | None = None, variant: AttentionVariant | None = None) → None
Compile (or retrieve cached) MLA decode kernel for the given config.
- Parameters:
kv_lora_rank (int) – Latent dimension (e.g. 512).
qk_rope_head_dim (int) – RoPE dimension (e.g. 64).
num_heads (int) – Number of attention heads (typically 128 for DeepSeek-V3).
page_size (int) – KV cache page size.
q_dtype (torch.dtype) – Query/KV data type (float16 or bfloat16).
out_dtype (Optional[torch.dtype]) – Output data type. Defaults to same as q_dtype.
is_var_seq (bool) – Whether sequence lengths vary across the batch.
enable_pdl (Optional[bool]) – Whether to enable Programmatic Dependent Launch. Auto-detects if None.
variant (Optional[AttentionVariant]) – Attention variant (ALiBi, SoftCapping, AttentionWithSink, etc.). None uses standard softmax attention.
- run(q: Tensor, kv_cache: Tensor, block_tables: Tensor, seq_lens: Tensor, max_seq_len: int, softmax_scale: float, output_scale: float = 1.0, out: Tensor | None = None) → Tensor
Run the MLA decode kernel.
- Parameters:
q (torch.Tensor) – [B, q_len, H, D_qk] where D_qk = kv_lora_rank + qk_rope_head_dim.
kv_cache (torch.Tensor) – [num_pages, page_size, D_total] (3D) or [num_pages, 1, page_size, D_total] (4D).
block_tables (torch.Tensor) – [B, max_pages] page table indices.
seq_lens (torch.Tensor) – [B] per-request KV sequence lengths.
max_seq_len (int) – Maximum sequence length across the batch.
softmax_scale (float) – Scale factor for QK^T before softmax.
output_scale (float) – Scale factor applied to the output.
out (Optional[torch.Tensor]) – Pre-allocated output [B, q_len, H, kv_lora_rank].
- Returns:
Output tensor [B, q_len, H, kv_lora_rank].
- Return type:
torch.Tensor
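The shape relationships above can be written down as a small pre-flight check. Both helper names are hypothetical, and the page count is an assumption: it takes a request of length max_seq_len to need ceil(max_seq_len / page_size) pages, which gives a lower bound on max_pages in block_tables.

```python
def mla_decode_shapes(batch, q_len, num_heads, kv_lora_rank=512, qk_rope_head_dim=64):
    """Expected q and output shapes, using D_qk = kv_lora_rank + qk_rope_head_dim."""
    d_qk = kv_lora_rank + qk_rope_head_dim
    return {
        "q": (batch, q_len, num_heads, d_qk),
        "out": (batch, q_len, num_heads, kv_lora_rank),
    }


def min_pages_per_request(max_seq_len, page_size):
    """Ceiling division: pages needed to hold max_seq_len tokens (an assumption)."""
    return -(-max_seq_len // page_size)


if __name__ == "__main__":
    print(mla_decode_shapes(batch=4, q_len=1, num_heads=128))
    print(min_pages_per_request(max_seq_len=130, page_size=64))
```

Comparing these tuples against tuple(q.shape) and tuple(out.shape) before calling run gives a cheap guard against layout mistakes.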
- class flashinfer.cute_dsl.attention.wrappers.batch_prefill.BatchPrefillCuteDSLWrapper(float_workspace_buffer: Tensor, use_cuda_graph: bool = False)
PyTorch-facing wrapper for the CuTe-DSL ragged-KV batch prefill kernel.
This wrapper exposes a plan + run API compatible with flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper, but compiles a CuTe-DSL kernel under the hood instead of the C++ FA2/FA3 path.
Example
    wrapper = BatchPrefillCuteDSLWrapper(workspace_buffer)
    wrapper.plan(qo_indptr, kv_indptr, num_qo_heads=32, num_kv_heads=8, head_dim_qk=128)
    out = wrapper.run(q, k, v)
- __init__(float_workspace_buffer: Tensor, use_cuda_graph: bool = False) → None
Initialise the wrapper and bind it to a workspace buffer.
- Parameters:
float_workspace_buffer (torch.Tensor) – Pre-allocated workspace buffer on the target CUDA device. Named for API parity with BatchPrefillWithRaggedKVCacheWrapper; callers typically pass a torch.uint8 tensor. The CuTe-DSL kernel itself does not consume this buffer, but it is retained so the wrapper can mirror the parent API.
use_cuda_graph (bool) – Whether the wrapper will be used inside a CUDA graph capture. Defaults to False.
- plan(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo=None, causal=True, sm_scale=1.0, q_data_type=torch.float16, kv_data_type=torch.float16, window_left: int = -1, variant: AttentionVariant | None = None) → None
Compile the FMHA prefill kernel for the given configuration.
- Parameters:
qo_indptr (torch.Tensor) – Cumulative query sequence lengths, shape [batch_size + 1].
kv_indptr (torch.Tensor) – Cumulative KV sequence lengths, shape [batch_size + 1].
num_qo_heads (int) – Number of query/output heads.
num_kv_heads (int) – Number of key/value heads (must divide num_qo_heads).
head_dim_qk (int) – Head dimension for queries and keys.
head_dim_vo (Optional[int]) – Head dimension for values and output. Must equal head_dim_qk if set.
causal (bool) – Whether to apply causal masking.
sm_scale (float) – Softmax scale factor (typically 1/sqrt(head_dim)).
q_data_type (torch.dtype) – Data type for queries (float16, bfloat16, or float8_e4m3fn).
kv_data_type (torch.dtype) – Data type for keys/values.
window_left (int) – Sliding window size. -1 disables sliding window.
variant (Optional[AttentionVariant]) – Attention variant (ALiBi, RPE, Sigmoid, etc.). None uses standard softmax.
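Several of these arguments are derived quantities, and the signature's default sm_scale is 1.0, so the typical 1/sqrt(head_dim) value has to be passed explicitly. The sketch below shows how they might be computed on the host; all three helper names are hypothetical, and in real use the indptr lists would be converted to integer tensors on the target device.

```python
import math
from itertools import accumulate


def build_indptr(seq_lens):
    """Cumulative lengths with a leading 0, giving batch_size + 1 entries."""
    return [0] + list(accumulate(seq_lens))


def gqa_group_size(num_qo_heads, num_kv_heads):
    """Check that num_kv_heads divides num_qo_heads; return heads per KV group."""
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_kv_heads must divide num_qo_heads")
    return num_qo_heads // num_kv_heads


def typical_sm_scale(head_dim_qk):
    """The common 1/sqrt(head_dim) softmax scale."""
    return 1.0 / math.sqrt(head_dim_qk)


if __name__ == "__main__":
    print(build_indptr([3, 5, 2]))
    print(gqa_group_size(32, 8))
    print(typical_sm_scale(128))
```

Running the divisibility check before plan turns an invalid head configuration into an early, readable error.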
- run(q: Tensor, k: Tensor, v: Tensor, out: Tensor | None = None) → Tensor
Run the prefill attention computation.
- Parameters:
q (torch.Tensor) – The query tensor with shape [total_q_len, num_qo_heads, head_dim_qk].
k (torch.Tensor) – The key tensor with shape [total_kv_len, num_kv_heads, head_dim_qk].
v (torch.Tensor) – The value tensor with shape [total_kv_len, num_kv_heads, head_dim_vo].
out (Optional[torch.Tensor], optional) – The output tensor. If None, a new tensor will be created.
- Returns:
The output tensor with shape [total_q_len, num_qo_heads, head_dim_vo].
- Return type:
torch.Tensor
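Because the output is ragged along its first dimension, per-request results are recovered by slicing with the same qo_indptr passed to plan. The helper below is a hypothetical pure-Python sketch of that slicing; with tensors the same index pairs would be used as out[start:end].

```python
def split_ragged(rows, indptr):
    """Split a flat sequence of rows into per-request chunks using indptr."""
    if not indptr or indptr[0] != 0 or indptr[-1] != len(rows):
        raise ValueError("indptr must start at 0 and end at len(rows)")
    return [rows[indptr[i]:indptr[i + 1]] for i in range(len(indptr) - 1)]


if __name__ == "__main__":
    flat = list("abcdefghij")  # stand-in for total_q_len = 10 output rows
    for request_id, chunk in enumerate(split_ragged(flat, [0, 3, 8, 10])):
        print(request_id, chunk)
```

Request i owns rows indptr[i] up to but not including indptr[i + 1], so chunk lengths match the original per-request query lengths.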