flashinfer.cute_dsl.rmsnorm_fp4quant
- flashinfer.cute_dsl.rmsnorm_fp4quant(input: Tensor, weight: Tensor, y_fp4: Tensor | None = None, block_scale: Tensor | None = None, global_scale: Tensor | None = None, eps: float = 1e-06, block_size: int = 16, scale_format: str | None = None, is_sf_swizzled_layout: bool = False, enable_pdl: bool | None = None) → Tuple[Tensor, Tensor]
Fused RMS normalization with FP4 quantization using CuTe-DSL.
Computes:
y = RMSNorm(input) * weight, then optionally applies global scaling (y = y / global_scale), and finally quantizes y to FP4.
- Parameters:
  - input (torch.Tensor) – Input tensor of shape (batch_size, hidden_size) or (batch_size, seq_len, hidden_size). Must be torch.float16 or torch.bfloat16.
  - weight (torch.Tensor) – Weight tensor for RMSNorm, shape (hidden_size,). Must have the same dtype as input.
  - y_fp4 (torch.Tensor, optional) – Output tensor for the quantized values in FP4 E2M1 format, with dtype torch.float4_e2m1fn_x2 (two FP4 values per element). Shape must be (batch_size, hidden_size // 2), or (batch_size, seq_len, hidden_size // 2) for 3D input. If None, it is allocated automatically.
  - block_scale (torch.Tensor, optional) – Output tensor for the per-block scale factors.
    - If is_sf_swizzled_layout=False (default): row-major layout with shape (batch_size, hidden_size // block_size), or (batch_size, seq_len, hidden_size // block_size) for 3D input.
    - If is_sf_swizzled_layout=True: swizzled layout for efficient tensor core access, flattened to shape (batch_size * hidden_size // block_size,). The swizzle uses 128x4 tiles in which the scales are arranged as [m_tile][k_tile][outer_m (32)][inner_m (4)][inner_k (4)].
    Dtype should be torch.float8_e4m3fn for the E4M3 format or torch.uint8 for the UE8M0 format. If None, it is allocated automatically.
  - global_scale (torch.Tensor, optional) – Global scale factor tensor of shape (1,) with dtype torch.float32. If provided, the RMSNorm output is divided by this value before quantization: y = rmsnorm(x, w) / global_scale. This is used for the NVFP4 format, where a pre-computed global scale lifts the per-block scales into the optimal dynamic range. If None, no global scaling is applied (equivalent to global_scale=1.0).
  - eps (float) – Epsilon for numerical stability in RMSNorm. Default is 1e-6.
  - block_size (int) – Number of elements per quantization block. Default is 16.
    - 16: NVFP4 format with E4M3 scale factors
    - 32: MXFP4 format with UE8M0 scale factors
  - scale_format (str, optional) – Scale factor format: "e4m3" or "ue8m0". If None, it is selected from block_size: "e4m3" for block_size=16 and "ue8m0" for block_size=32.
  - is_sf_swizzled_layout (bool) – If True, scale factors are written in a swizzled layout optimized for tensor core GEMM operations. The swizzle uses 128x4 tiles, and each scale factor is stored at flat offset m_tile_idx * k_tiles * 512 + k_tile_idx * 512 + outer_m * 16 + inner_m * 4 + inner_k, where outer_m = row % 32, inner_m = (row % 128) // 32, etc. Default is False (row-major layout).
  - enable_pdl (bool, optional) – Whether to launch with Programmatic Dependent Launch (PDL). When None (default) or True, PDL is enabled only if the current device supports it (probed via device_support_pdl()). Pass False to force-disable it. See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
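The swizzled offset formula for is_sf_swizzled_layout=True can be written out as a small helper. This is a sketch, not a FlashInfer function: swizzled_sf_offset is a hypothetical name, and the index definitions the text leaves as "etc." are assumed here to be m_tile_idx = row // 128, k_tile_idx = sf_col // 4, inner_k = sf_col % 4 and k_tiles = ceil(num_sf_cols / 4), where sf_col is the scale-factor column and num_sf_cols = hidden_size // block_size.

```python
def swizzled_sf_offset(row, sf_col, num_sf_cols):
    """Flat offset of scale factor (row, sf_col) in a 128x4-tiled swizzled layout.

    Assumed (not stated in the docs): m_tile_idx = row // 128,
    k_tile_idx = sf_col // 4, inner_k = sf_col % 4,
    k_tiles = ceil(num_sf_cols / 4).
    """
    k_tiles = (num_sf_cols + 3) // 4   # 4 scale columns per tile (assumed rounding up)
    m_tile_idx = row // 128            # 128 rows per tile (assumed)
    k_tile_idx = sf_col // 4           # assumed
    outer_m = row % 32                 # from the docs
    inner_m = (row % 128) // 32        # from the docs
    inner_k = sf_col % 4               # assumed
    # Each 128x4 tile occupies 512 consecutive entries.
    return (m_tile_idx * k_tiles * 512 + k_tile_idx * 512
            + outer_m * 16 + inner_m * 4 + inner_k)
```

For example, hidden_size=4096 with block_size=16 gives num_sf_cols=256, so k_tiles=64 and each 128-row band of tiles spans 64 * 512 = 32768 scale entries. Within one tile, outer_m * 16 + inner_m * 4 + inner_k visits 0..511 exactly once, so the mapping is a permutation of the row-major scale matrix.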
- Returns:
A tuple (y_fp4, block_scale):
  - y_fp4: quantized FP4 E2M1 values, packed two per byte.
  - block_scale: per-block scale factors.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
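As a reference for the "Computes:" line, the normalization and optional global scaling can be written for a single row in pure Python. This sketch (rmsnorm_row is a hypothetical helper) uses the standard RMSNorm definition x / sqrt(mean(x**2) + eps) * weight and a plain float in place of the (1,)-shaped global_scale tensor. It runs in double precision and does not reproduce the kernel's numerics or its FP4 step.

```python
import math


def rmsnorm_row(x, weight, eps=1e-6, global_scale=None):
    """Hypothetical reference: RMSNorm of one row, then optional division by global_scale."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    y = [v * inv_rms * w for v, w in zip(x, weight)]
    if global_scale is not None:
        # Mirrors y = rmsnorm(x, w) / global_scale; None behaves like 1.0.
        y = [v / global_scale for v in y]
    return y
```

For instance, rmsnorm_row([2.0] * 4, [1.0, 2.0, 3.0, 4.0], eps=0.0) returns [1.0, 2.0, 3.0, 4.0], and passing global_scale=2.0 halves every element.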
Notes
Requires SM100+ (Blackwell) for FP4 quantization PTX intrinsics.
For block_size=16 (NVFP4): uses E4M3 scale factors (max value 448.0).
For block_size=32 (MXFP4): uses UE8M0 scale factors (power-of-2 scales).
FP4 E2M1 format has a max representable value of 6.0.
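The numeric limits above can be made concrete with a small pure-Python sketch of how one block might be quantized. It assumes round-to-nearest-even onto the E2M1 grid, an NVFP4 per-block scale of amax / 6.0 clamped to 448 (without rounding the scale itself to an E4M3 value), and a smallest-power-of-two rule for UE8M0 scales. The helper names are hypothetical and this is not the CuTe-DSL kernel's code.

```python
import math

# Magnitudes representable in FP4 E2M1, ordered by 3-bit code (sign stored separately).
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
E2M1_MAX = 6.0    # largest E2M1 magnitude
E4M3_MAX = 448.0  # largest finite E4M3 value


def round_to_e2m1(x):
    """Round x to the nearest E2M1 value, saturating at +/-6.0.

    Ties go to the even code (mantissa bit 0), i.e. round-to-nearest-even.
    """
    mag = min(abs(x), E2M1_MAX)
    code = min(range(len(E2M1_GRID)),
               key=lambda i: (abs(E2M1_GRID[i] - mag), i % 2))
    return math.copysign(E2M1_GRID[code], x)


def nvfp4_block(block):
    """Hypothetical NVFP4-style block: scale = amax / 6.0, clamped to the E4M3 range.

    Simplification: the scale is not rounded to a representable E4M3 value.
    """
    amax = max(abs(v) for v in block)
    scale = min(amax / E2M1_MAX, E4M3_MAX)
    if scale == 0.0:
        return 0.0, [0.0] * len(block)
    return scale, [round_to_e2m1(v / scale) for v in block]


def ue8m0_scale(amax):
    """Hypothetical MXFP4-style scale: assumed smallest 2**k with amax / 2**k <= 6.0."""
    if amax == 0.0:
        return 1.0
    return 2.0 ** math.ceil(math.log2(amax / E2M1_MAX))
```

For example, nvfp4_block([12.0, -6.0, 3.0, 1.0]) yields a scale of 2.0 and values [6.0, -3.0, 1.5, 0.5], while a block whose amax exceeds 6.0 * 448 hits the scale clamp and its largest elements saturate at 6.0.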