flashinfer.cute_dsl.add_rmsnorm_fp4quant

flashinfer.cute_dsl.add_rmsnorm_fp4quant(input: Tensor, residual: Tensor, weight: Tensor, y_fp4: Tensor | None = None, block_scale: Tensor | None = None, global_scale: Tensor | None = None, eps: float = 1e-06, block_size: int = 16, scale_format: str | None = None, is_sf_swizzled_layout: bool = False, output_both_sf_layouts: bool = False, block_scale_unswizzled: Tensor | None = None, enable_pdl: bool | None = None) → Tuple[Tensor, Tensor] | Tuple[Tensor, Tensor, Tensor]

Fused Add + RMS normalization + FP4 quantization using CuTe-DSL.

Computes:
  1. residual = residual + input (in-place update)

  2. y = RMSNorm(residual) * weight

  3. Optionally applies global scaling (y = y / global_scale)

  4. Quantizes y to FP4 (E2M1), with one scale factor per block of block_size elements

The residual tensor is modified in place: when the call returns, it holds input + residual (the sum before normalization).
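Steps 1 to 3 can be written out for a single row as a plain-Python reference, which is handy when checking kernel outputs. This is a minimal sketch, not the CuTe-DSL kernel: the name add_rmsnorm_reference is hypothetical, and it assumes the standard RMSNorm definition with eps added inside the square root.

```python
import math
from typing import List, Optional, Tuple


def add_rmsnorm_reference(
    x: List[float],
    residual: List[float],
    weight: List[float],
    eps: float = 1e-6,
    global_scale: Optional[float] = None,
) -> Tuple[List[float], List[float]]:
    """Hypothetical reference for steps 1-3 on one row; returns (h, y).

    h is what the kernel writes back into `residual`; y is the value that
    step 4 quantizes to FP4.
    """
    # Step 1: h = residual + input.
    h = [r + xi for r, xi in zip(residual, x)]
    # Step 2: RMSNorm over the hidden dimension (assumed standard form),
    # then the elementwise weight.
    rms = math.sqrt(sum(v * v for v in h) / len(h) + eps)
    y = [v / rms * w for v, w in zip(h, weight)]
    # Step 3 (optional): divide by the global scale; None behaves like 1.0.
    if global_scale is not None:
        y = [v / global_scale for v in y]
    return h, y
```

For example, input [1, -1, 1, -1] with residual [1, 1, 1, 1] gives h = [2, 0, 2, 0]; with unit weights the mean square is 2, so y is approximately [1.4142, 0, 1.4142, 0].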

Parameters:
  • input (torch.Tensor) – Input tensor, shape (batch_size, hidden_size) or (batch_size, seq_len, hidden_size). Must be torch.float16 or torch.bfloat16. Read-only.

  • residual (torch.Tensor) – Residual tensor. Must have the same shape and dtype as input. Modified in-place to contain residual + input.

  • weight (torch.Tensor) – Weight tensor for RMSNorm, shape (hidden_size,). Must have the same dtype as input.

  • y_fp4 (torch.Tensor, optional) – Output tensor for quantized values in FP4_E2M1 format with dtype torch.float4_e2m1fn_x2. Shape must be (batch_size, hidden_size // 2) or matching 3D input. If None, will be allocated automatically.

  • block_scale (torch.Tensor, optional) –

    Output tensor for per-block scale factors.

    • If is_sf_swizzled_layout=False and output_both_sf_layouts=False: row-major layout with shape (batch_size, hidden_size // block_size) or matching 3D input.

    • If is_sf_swizzled_layout=True or output_both_sf_layouts=True: swizzled layout for efficient tensor core access, stored as a flattened 1-D tensor of shape (batch_size * hidden_size // block_size,). The swizzle pattern uses 128x4 tiles in which scales are arranged as [m_tile][k_tile][outer_m (32)][inner_m (4)][inner_k (4)].

    Dtype should be torch.float8_e4m3fn for E4M3 format or torch.uint8 for UE8M0 format. If None, will be allocated automatically.

  • global_scale (torch.Tensor, optional) – Global scale factor tensor of shape (1,) with dtype torch.float32. If provided, the RMSNorm output is divided by this value before quantization: y = rmsnorm(h, w) / global_scale, where h = input + residual. This is used for the NVFP4 format, where a pre-computed per-tensor global scale brings the per-block E4M3 scale factors into their representable range. If None, no global scaling is applied (equivalent to global_scale=1.0).

  • eps (float) – Epsilon for numerical stability in RMSNorm. Default is 1e-6.

  • block_size (int) –

    Number of elements per quantization block. Default is 16.

    • 16: NVFP4 format with E4M3 scale factors

    • 32: MXFP4 format with UE8M0 scale factors

  • scale_format (str, optional) – Scale factor format: "e4m3" or "ue8m0". If None, auto-selects based on block_size: "e4m3" for block_size=16, "ue8m0" for block_size=32.

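  • is_sf_swizzled_layout (bool) – If True, output scale factors in the swizzled layout optimized for tensor core GEMM operations. The swizzle uses 128x4 tiles: the scale factor at (row, col), where col is the scale-factor column (block index along hidden_size), is stored at flat offset m_tile_idx * k_tiles * 512 + k_tile_idx * 512 + outer_m * 16 + inner_m * 4 + inner_k, with m_tile_idx = row // 128, k_tile_idx = col // 4, outer_m = row % 32, inner_m = (row % 128) // 32, inner_k = col % 4, and k_tiles the number of 4-column tiles along the scale-factor dimension. Default is False (row-major layout). Note: This parameter is ignored when output_both_sf_layouts=True.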

  • output_both_sf_layouts (bool) – If True, return both swizzled and unswizzled scale factors. When enabled, block_scale contains the swizzled layout and block_scale_unswizzled contains the row-major layout. This overrides is_sf_swizzled_layout. Default is False.

  • block_scale_unswizzled (torch.Tensor, optional) – Output tensor for unswizzled per-block scale factors (row-major layout). Only used when output_both_sf_layouts=True. Shape is (batch_size, hidden_size // block_size) or matching 3D input. Dtype should be torch.float8_e4m3fn for E4M3 format or torch.uint8 for UE8M0 format. If None, will be allocated automatically when output_both_sf_layouts=True.

  • enable_pdl (bool, optional) – Whether to launch with Programmatic Dependent Launch (PDL). When None (default) or True, PDL is enabled only if the current device supports it (probed via device_support_pdl()). Pass False to force-disable. See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
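The swizzled scale-factor layout described under block_scale and is_sf_swizzled_layout can be expressed as an index function. This is a minimal sketch: swizzled_sf_offset is a hypothetical name, col is the scale-factor column (hidden index // block_size), and k_tiles is assumed to be the number of 4-column tiles rounded up.

```python
def swizzled_sf_offset(row: int, col: int, num_sf_cols: int) -> int:
    """Hypothetical helper: flat offset of scale factor (row, col) in the
    128x4-tiled layout.

    row: token index; col: scale-factor column (hidden index // block_size);
    num_sf_cols: hidden_size // block_size.
    """
    k_tiles = (num_sf_cols + 3) // 4  # 4-column tiles per 128-row band (rounded up, assumed)
    m_tile_idx = row // 128            # which 128-row tile
    k_tile_idx = col // 4              # which 4-column tile
    outer_m = row % 32                 # row within a 32-row group
    inner_m = (row % 128) // 32        # which of the four 32-row groups
    inner_k = col % 4                  # column within the tile
    # Each 128x4 tile occupies 512 contiguous slots.
    return (m_tile_idx * k_tiles * 512 + k_tile_idx * 512
            + outer_m * 16 + inner_m * 4 + inner_k)
```

Within one tile, rows 0 to 31 are 16 slots apart, while rows 32, 64 and 96 land at offsets 4, 8 and 12, so each 16-slot group interleaves four 32-row groups by four columns. For a 128-row, 8-column scale matrix the mapping is a one-to-one cover of offsets 0 to 1023.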

Returns:

When output_both_sf_layouts=False:

A tuple of (y_fp4, block_scale):

  • y_fp4: Quantized FP4 (E2M1) values, packed two per byte (see y_fp4 above for the dtype).

  • block_scale: Per-block scale factors (swizzled if is_sf_swizzled_layout=True, otherwise row-major).

When output_both_sf_layouts=True:

A tuple of (y_fp4, block_scale, block_scale_unswizzled):

  • y_fp4: Quantized FP4 (E2M1) values, packed two per byte (see y_fp4 above for the dtype).

  • block_scale: Per-block scale factors in swizzled layout.

  • block_scale_unswizzled: Per-block scale factors in row-major layout.

Return type:

Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
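The shape rules spread across the parameter and return descriptions can be collected into one helper. This is a sketch that mirrors the documented shapes only: expected_output_shapes is a hypothetical name, and for 3D input it assumes that batch_size in the flattened swizzled shape means batch_size * seq_len.

```python
from math import prod
from typing import Dict, Tuple


def expected_output_shapes(
    input_shape: Tuple[int, ...],
    block_size: int = 16,
    is_sf_swizzled_layout: bool = False,
    output_both_sf_layouts: bool = False,
) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper: shapes of the returned tensors per the documented rules."""
    *lead, hidden = input_shape
    num_sf = hidden // block_size
    row_major = (*lead, num_sf)
    # Assumption: for 3D input, leading dims are folded into one row count.
    swizzled = (prod(lead) * num_sf,)
    # Two FP4 values per byte along the hidden dimension.
    shapes = {"y_fp4": (*lead, hidden // 2)}
    if output_both_sf_layouts:
        # Overrides is_sf_swizzled_layout: swizzled and row-major copies.
        shapes["block_scale"] = swizzled
        shapes["block_scale_unswizzled"] = row_major
    else:
        shapes["block_scale"] = swizzled if is_sf_swizzled_layout else row_major
    return shapes
```

For a (4, 4096) input with the defaults, this gives y_fp4 of shape (4, 2048) and a row-major block_scale of shape (4, 256).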

Notes

  • Requires SM100+ (Blackwell) for FP4 quantization PTX intrinsics.

  • For block_size=16 (NVFP4): uses E4M3 scale factors (max value 448.0).

  • For block_size=32 (MXFP4): uses UE8M0 scale factors (power-of-2 scales).

  • FP4 E2M1 format has a max representable value of 6.0.
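The per-block quantization and the scale-factor notes above can be illustrated on a single block in plain Python. This sketch makes assumptions that may not match the kernel bit for bit: E2M1 rounding is round-to-nearest-even with saturation at 6.0, the E4M3 scale is amax / 6.0 clamped to 448.0 without modeling the FP8 rounding of the scale itself, and the UE8M0 scale is rounded up to the next power of two. The names round_to_e2m1 and quantize_block are hypothetical.

```python
import math
from typing import List, Tuple

# Non-negative magnitudes representable in FP4 E2M1; the sign is a separate bit.
# Even indices have a zero mantissa bit, which is used for ties-to-even.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
E2M1_MAX = 6.0
E4M3_MAX = 448.0


def round_to_e2m1(v: float) -> float:
    """Round to the nearest E2M1 value, ties to even, saturating at +/-6.0."""
    mag = min(abs(v), E2M1_MAX)
    idx = min(range(len(E2M1_VALUES)),
              key=lambda i: (abs(E2M1_VALUES[i] - mag), i % 2))
    return math.copysign(E2M1_VALUES[idx], v)


def quantize_block(block: List[float], block_size: int) -> Tuple[float, List[float]]:
    """Return (scale, quantized values) for one block; dequantize as q * scale."""
    if len(block) != block_size:
        raise ValueError("block length must equal block_size")
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        # All-zero block: any scale reproduces it exactly; 1.0 is used here.
        return 1.0, [0.0] * block_size
    if block_size == 16:
        # NVFP4: E4M3 scale clamped to 448 (FP8 rounding of the scale not modeled).
        scale = min(amax / E2M1_MAX, E4M3_MAX)
    elif block_size == 32:
        # MXFP4: UE8M0 scale is a power of two; rounding the exponent up
        # keeps every |v / scale| <= 6.0.
        scale = 2.0 ** math.ceil(math.log2(amax / E2M1_MAX))
    else:
        raise ValueError("block_size must be 16 (NVFP4) or 32 (MXFP4)")
    return scale, [round_to_e2m1(v / scale) for v in block]
```

The power-of-two constraint is what separates the two formats: a 32-element block whose largest value is 10.0 gets an MXFP4 scale of 2.0, so 10.0 maps to 5.0 and rounds to 4.0, whereas an NVFP4 scale of 10.0 / 6.0 would represent it exactly at 6.0.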