flashinfer.cute_dsl.add_rmsnorm_fp4quant
- flashinfer.cute_dsl.add_rmsnorm_fp4quant(input: Tensor, residual: Tensor, weight: Tensor, y_fp4: Tensor | None = None, block_scale: Tensor | None = None, global_scale: Tensor | None = None, eps: float = 1e-06, block_size: int = 16, scale_format: str | None = None, is_sf_swizzled_layout: bool = False, output_both_sf_layouts: bool = False, block_scale_unswizzled: Tensor | None = None, enable_pdl: bool | None = None) -> Tuple[Tensor, Tensor] | Tuple[Tensor, Tensor, Tensor]
Fused Add + RMS normalization + FP4 quantization using CuTe-DSL.
- Computes:
  - residual = residual + input (in-place update)
  - y = RMSNorm(residual) * weight
  - Optionally applies global scaling (y = y / global_scale)
  - Quantizes y to FP4
The residual tensor is modified in place: on return it holds residual + input.
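The computation steps listed above can be sketched in pure Python for a single row. This is a reference sketch only, not the CuTe-DSL kernel: the helper names are hypothetical, the per-block scale rule (max magnitude divided by 6.0) is an assumption, the scale is kept as a plain float rather than rounded to E4M3 or UE8M0, and tie-breaking may differ from the hardware conversion.

```python
import math

# Magnitudes representable in FP4 E2M1 (the sign is handled separately).
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def add_rmsnorm_reference(x, residual, weight, eps=1e-6, global_scale=None):
    """Hypothetical single-row reference: returns (h, y).

    h is the fused residual (what the real function writes back into
    `residual`); y is the normalized, optionally globally scaled output.
    """
    h = [r + v for r, v in zip(residual, x)]                # residual + input
    rms = math.sqrt(sum(v * v for v in h) / len(h) + eps)   # RMS with epsilon
    y = [v / rms * w for v, w in zip(h, weight)]            # RMSNorm(h) * weight
    if global_scale is not None:
        y = [v / global_scale for v in y]                   # optional global scaling
    return h, y


def quantize_block_e2m1(block):
    """Assumed scheme: scale = max|v| / 6.0, then snap to the nearest E2M1 value."""
    amax = max(abs(v) for v in block)
    scale = amax / 6.0 if amax > 0 else 1.0
    quantized = []
    for v in block:
        mag = min(E2M1_VALUES, key=lambda c: abs(abs(v) / scale - c))
        quantized.append(math.copysign(mag, v))
    return quantized, scale
```

Multiplying each quantized value by its block scale (and by global_scale, if one was used) gives an approximation of the pre-quantization y.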
- Parameters:
  - input (torch.Tensor) – Input tensor, shape (batch_size, hidden_size) or (batch_size, seq_len, hidden_size). Must be torch.float16 or torch.bfloat16. Read-only.
  - residual (torch.Tensor) – Residual tensor. Must have the same shape and dtype as input. Modified in-place to contain residual + input.
  - weight (torch.Tensor) – Weight tensor for RMSNorm, shape (hidden_size,). Must have the same dtype as input.
  - y_fp4 (torch.Tensor, optional) – Output tensor for quantized values in FP4_E2M1 format with dtype torch.float4_e2m1fn_x2. Shape must be (batch_size, hidden_size // 2) or matching 3D input. If None, will be allocated automatically.
  - block_scale (torch.Tensor, optional) – Output tensor for per-block scale factors.
    - If is_sf_swizzled_layout=False and output_both_sf_layouts=False: row-major layout with shape (batch_size, hidden_size // block_size) or matching 3D input.
    - If is_sf_swizzled_layout=True or output_both_sf_layouts=True: swizzled layout for efficient tensor core access, with shape (batch_size * hidden_size // block_size,) flattened. The swizzle pattern uses 128x4 tiles where scales are arranged as: [m_tile][k_tile][outer_m (32)][inner_m (4)][inner_k (4)].
    Dtype should be torch.float8_e4m3fn for E4M3 format or torch.uint8 for UE8M0 format. If None, will be allocated automatically.
  - global_scale (torch.Tensor, optional) – Global scale factor tensor of shape (1,) with dtype torch.float32. If provided, the RMSNorm output is divided by this value before quantization: y = rmsnorm(h, w) / global_scale where h = input + residual. This is used for NVFP4 format where a pre-computed global scale lifts per-block scales into optimal dynamic range. If None, no global scaling is applied (equivalent to global_scale=1.0).
  - eps (float) – Epsilon for numerical stability in RMSNorm. Default is 1e-6.
  - block_size (int) – Number of elements per quantization block. Default is 16.
    - 16: NVFP4 format with E4M3 scale factors
    - 32: MXFP4 format with UE8M0 scale factors
  - scale_format (str, optional) – Scale factor format: "e4m3" or "ue8m0". If None, auto-selects based on block_size: "e4m3" for block_size=16, "ue8m0" for block_size=32.
  - is_sf_swizzled_layout (bool) – If True, output scale factors in swizzled layout optimized for tensor core GEMM operations. The swizzle uses 128x4 tiles with the pattern: [m_tile_idx * k_tiles * 512 + k_tile_idx * 512 + outer_m * 16 + inner_m * 4 + inner_k] where outer_m = row % 32, inner_m = (row % 128) // 32, etc. Default is False (row-major layout). Note: This parameter is ignored when output_both_sf_layouts=True.
  - output_both_sf_layouts (bool) – If True, return both swizzled and unswizzled scale factors. When enabled, block_scale contains the swizzled layout and block_scale_unswizzled contains the row-major layout. This overrides is_sf_swizzled_layout. Default is False.
  - block_scale_unswizzled (torch.Tensor, optional) – Output tensor for unswizzled per-block scale factors (row-major layout). Only used when output_both_sf_layouts=True. Shape is (batch_size, hidden_size // block_size) or matching 3D input. Dtype should be torch.float8_e4m3fn for E4M3 format or torch.uint8 for UE8M0 format. If None, will be allocated automatically when output_both_sf_layouts=True.
  - enable_pdl (bool, optional) – Whether to launch with Programmatic Dependent Launch (PDL). When None (default) or True, PDL is enabled only if the current device supports it (probed via device_support_pdl()). Pass False to force-disable. See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
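The swizzled offset formula given for is_sf_swizzled_layout can be written out as a small function. This is a sketch of one reading of that formula, not code from the library: the function name is hypothetical, and the terms the description elides with "etc." (m_tile_idx, k_tile_idx, inner_k, and k_tiles) are assumptions inferred from the 128x4 tile shape.

```python
def swizzled_sf_offset(row, col, num_sf_cols):
    """Flat offset of scale factor (row, col) in the swizzled layout.

    row: token row index; col: scale-factor column (block index along hidden).
    num_sf_cols: hidden_size // block_size.
    """
    k_tiles = (num_sf_cols + 3) // 4   # assumed: 4 scale columns per tile, rounded up
    m_tile_idx = row // 128            # assumed: 128 rows per tile
    k_tile_idx = col // 4              # assumed
    outer_m = row % 32                 # from the documented formula
    inner_m = (row % 128) // 32        # from the documented formula
    inner_k = col % 4                  # assumed
    return (m_tile_idx * k_tiles * 512
            + k_tile_idx * 512
            + outer_m * 16
            + inner_m * 4
            + inner_k)
```

Under these assumptions each 128x4 tile occupies a contiguous run of 512 scale factors, which matches the [outer_m (32)][inner_m (4)][inner_k (4)] arrangement described for block_scale.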
- Returns:
  - When output_both_sf_layouts=False: A tuple of (y_fp4, block_scale):
    - y_fp4: Quantized FP4 values packed as uint8.
    - block_scale: Per-block scale factors (swizzled or row-major based on is_sf_swizzled_layout).
  - When output_both_sf_layouts=True: A tuple of (y_fp4, block_scale, block_scale_unswizzled):
    - y_fp4: Quantized FP4 values packed as uint8.
    - block_scale: Per-block scale factors in swizzled layout.
    - block_scale_unswizzled: Per-block scale factors in row-major layout.
- Return type:
Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
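As a quick check on the shapes of the returned tensors, the documented shape rules can be collected in a helper. The function below is hypothetical and follows only the shapes stated in the parameter descriptions. For 3D input with a swizzled layout it assumes the flattened length is the product of all leading dimensions times hidden_size // block_size, and it does not model any padding the kernel may apply.

```python
def expected_output_shapes(input_shape, block_size=16, swizzled=False):
    """Return (y_fp4_shape, block_scale_shape) for a 2D or 3D input shape."""
    *lead, hidden = input_shape
    if hidden % block_size != 0:
        raise ValueError("hidden_size must be divisible by block_size")
    # Two FP4 values are packed per element along the hidden dimension.
    y_shape = (*lead, hidden // 2)
    if swizzled:
        rows = 1
        for dim in lead:               # assumed: all leading dims are flattened
            rows *= dim
        sf_shape = (rows * hidden // block_size,)
    else:
        sf_shape = (*lead, hidden // block_size)
    return y_shape, sf_shape
```

For example, a (4, 4096) input with block_size=16 gives a y_fp4 shape of (4, 2048) and a row-major block_scale shape of (4, 256).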
Notes
Requires SM100+ (Blackwell) for FP4 quantization PTX intrinsics.
For block_size=16 (NVFP4): uses E4M3 scale factors (max value 448.0).
For block_size=32 (MXFP4): uses UE8M0 scale factors (power-of-2 scales).
FP4 E2M1 format has a max representable value of 6.0.
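The pairing of block size and scale format in the notes above can be illustrated with two small helpers. Both names are hypothetical and neither is part of flashinfer. The first mirrors the documented scale_format default; the second decodes a UE8M0 byte under the usual E8M0 convention (8-bit exponent, bias 127, 0xFF reserved for NaN), which is assumed here rather than stated in this page.

```python
def resolve_scale_format(block_size, scale_format=None):
    """Mirror the documented default: e4m3 for block_size=16, ue8m0 for 32."""
    if scale_format is not None:
        if scale_format not in ("e4m3", "ue8m0"):
            raise ValueError(f"unknown scale_format: {scale_format!r}")
        return scale_format
    defaults = {16: "e4m3", 32: "ue8m0"}
    if block_size not in defaults:
        raise ValueError(f"no default scale_format for block_size={block_size}")
    return defaults[block_size]


def decode_ue8m0(byte):
    """Decode a UE8M0 scale byte to a float power of two (assumed bias 127)."""
    if not 0 <= byte <= 255:
        raise ValueError("UE8M0 values are single bytes")
    if byte == 255:
        return float("nan")            # assumed NaN encoding
    return 2.0 ** (byte - 127)
```

Because UE8M0 carries no mantissa bits, every decoded scale is an exact power of two, which is what the MXFP4 note refers to.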