flashinfer.cute_dsl.rmsnorm_fp4quant

flashinfer.cute_dsl.rmsnorm_fp4quant(input: Tensor, weight: Tensor, y_fp4: Tensor | None = None, block_scale: Tensor | None = None, global_scale: Tensor | None = None, eps: float = 1e-06, block_size: int = 16, scale_format: str | None = None, is_sf_swizzled_layout: bool = False, enable_pdl: bool | None = None) → Tuple[Tensor, Tensor]

Fused RMS normalization with FP4 quantization using CuTe-DSL.

Computes y = RMSNorm(input) * weight, optionally divides the result by a global scale (y = y / global_scale), and then quantizes y to FP4 with per-block scale factors.
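As an unquantized reference for the computation above, the sketch below evaluates one row in plain Python. It assumes the usual RMSNorm formulation with eps added inside the square root, y = x / sqrt(mean(x^2) + eps) * weight; the helper name rmsnorm_reference is hypothetical and not part of flashinfer.

```python
import math
from typing import List, Optional


def rmsnorm_reference(x: List[float], weight: List[float], eps: float = 1e-6,
                      global_scale: Optional[float] = None) -> List[float]:
    """Hypothetical reference for one row: x / sqrt(mean(x^2) + eps) * weight [/ global_scale].

    Assumes eps is added inside the square root (the common RMSNorm convention).
    """
    if len(x) != len(weight):
        raise ValueError("x and weight must both have length hidden_size")
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    y = [v * inv_rms * w for v, w in zip(x, weight)]
    if global_scale is not None:
        # Optional global scaling applied before FP4 quantization.
        y = [v / global_scale for v in y]
    return y
```

Comparing the dequantized kernel output against a reference like this (within FP4 rounding error) is one way to sanity-check a call.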

Parameters:
  • input (torch.Tensor) – Input tensor, shape (batch_size, hidden_size) or (batch_size, seq_len, hidden_size). Must be torch.float16 or torch.bfloat16.

  • weight (torch.Tensor) – Weight tensor for RMSNorm, shape (hidden_size,). Must have the same dtype as input.

  • y_fp4 (torch.Tensor, optional) – Output tensor for quantized values in FP4_E2M1 format with dtype torch.float4_e2m1fn_x2. Because two FP4 values are packed per element, the shape must be (batch_size, hidden_size // 2), or (batch_size, seq_len, hidden_size // 2) for 3D input. If None, it is allocated automatically.

  • block_scale (torch.Tensor, optional) –

    Output tensor for per-block scale factors.

    • If is_sf_swizzled_layout=False (default): row-major layout with shape (batch_size, hidden_size // block_size), or (batch_size, seq_len, hidden_size // block_size) for 3D input.

    • If is_sf_swizzled_layout=True: swizzled layout for efficient tensor core access, flattened to a 1D tensor of shape (batch_size * hidden_size // block_size,). The swizzle uses 128x4 tiles, with scales arranged as [m_tile][k_tile][outer_m (32)][inner_m (4)][inner_k (4)].

    The dtype should be torch.float8_e4m3fn for the E4M3 format or torch.uint8 for the UE8M0 format. If None, it is allocated automatically.

  • global_scale (torch.Tensor, optional) – Global scale factor tensor of shape (1,) with dtype torch.float32. If provided, the RMSNorm output is divided by this value before quantization: y = rmsnorm(x, w) / global_scale. This is used for the NVFP4 format, where a precomputed global scale brings the per-block E4M3 scale factors into their representable range. If None, no global scaling is applied (equivalent to global_scale=1.0).

  • eps (float) – Epsilon for numerical stability in RMSNorm. Default is 1e-6.

  • block_size (int) –

    Number of elements per quantization block. Default is 16.

    • 16: NVFP4 format with E4M3 scale factors

    • 32: MXFP4 format with UE8M0 scale factors

  • scale_format (str, optional) – Scale factor format: "e4m3" or "ue8m0". If None, auto-selects based on block_size: "e4m3" for block_size=16, "ue8m0" for block_size=32.

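  • is_sf_swizzled_layout (bool) – If True, output scale factors in a swizzled layout optimized for tensor core GEMM operations. The swizzle uses 128x4 tiles: the scale factor for row r and scale column c (c = element_column // block_size) is stored at index m_tile_idx * k_tiles * 512 + k_tile_idx * 512 + outer_m * 16 + inner_m * 4 + inner_k, where m_tile_idx = r // 128, outer_m = r % 32, inner_m = (r % 128) // 32, k_tile_idx = c // 4, inner_k = c % 4, and k_tiles is the number of 4-column tiles. Default is False (row-major layout).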

  • enable_pdl (bool, optional) – Whether to launch with Programmatic Dependent Launch (PDL). When None (default) or True, PDL is enabled only if the current device supports it (probed via device_support_pdl()). Pass False to force-disable. See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
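The shape rules for y_fp4 and block_scale, and the swizzled index formula given for is_sf_swizzled_layout, can be sketched in plain Python as follows. Both helper names are hypothetical. The sketch assumes hidden_size is divisible by block_size and that k_tiles is the number of scale columns rounded up to a multiple of 4; neither assumption is stated explicitly on this page.

```python
from typing import Tuple


def row_major_shapes(input_shape: Tuple[int, ...], block_size: int = 16):
    """Hypothetical helper: expected (y_fp4, block_scale) shapes for the row-major layout."""
    *lead, hidden_size = input_shape
    if hidden_size % block_size:
        # Assumption: hidden_size must be a multiple of block_size.
        raise ValueError("hidden_size must be divisible by block_size")
    fp4_shape = tuple(lead) + (hidden_size // 2,)           # two FP4 values per element
    scale_shape = tuple(lead) + (hidden_size // block_size,)  # one scale per block
    return fp4_shape, scale_shape


def swizzled_sf_offset(row: int, sf_col: int, num_sf_cols: int) -> int:
    """Hypothetical helper: flat index of scale (row, sf_col) in the 128x4-tiled layout."""
    k_tiles = (num_sf_cols + 3) // 4           # assumed: scale columns padded to a multiple of 4
    m_tile_idx, row_in_tile = divmod(row, 128)
    outer_m = row_in_tile % 32                 # == row % 32
    inner_m = row_in_tile // 32                # == (row % 128) // 32
    k_tile_idx, inner_k = divmod(sf_col, 4)
    return (m_tile_idx * k_tiles * 512 + k_tile_idx * 512
            + outer_m * 16 + inner_m * 4 + inner_k)
```

Within one 128x4 tile, consecutive rows are 16 entries apart, while rows 32 apart are only 4 entries apart, which is the [outer_m][inner_m][inner_k] ordering described above.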

Returns:

A tuple of (y_fp4, block_scale):

  • y_fp4: Quantized FP4 (E2M1) values, packed two per element along the last dimension (hidden_size // 2).

  • block_scale: Per-block scale factors, in row-major or swizzled layout depending on is_sf_swizzled_layout.

Return type:

Tuple[torch.Tensor, torch.Tensor]
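To illustrate what a packed y_fp4 byte holds, the sketch below encodes values as 4-bit E2M1 codes (1 sign bit, 2 exponent bits, 1 mantissa bit; magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6) and packs two codes per byte. The rounding mode (round to nearest, ties to even, saturating at 6.0) and the nibble order (even-indexed element in the low nibble) are assumptions for illustration, not confirmed by this page; the helper names are hypothetical.

```python
from typing import List

# E2M1 magnitudes indexed by the 3-bit exponent:mantissa code.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def encode_e2m1(v: float) -> int:
    """Hypothetical encoder: nearest E2M1 code (ties to even code, saturating at +/-6.0)."""
    sign = 0x8 if v < 0 else 0x0
    mag = min(abs(v), 6.0)
    # Closest magnitude; on a tie prefer the even code (mantissa bit 0).
    code = min(range(8), key=lambda c: (abs(E2M1_MAGNITUDES[c] - mag), c % 2))
    return sign | code


def pack_fp4(values: List[float]) -> bytes:
    """Hypothetical packer: two codes per byte, element 2*i in the low nibble (assumed order)."""
    if len(values) % 2:
        raise ValueError("need an even number of values")
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        out.append(encode_e2m1(lo) | (encode_e2m1(hi) << 4))
    return bytes(out)
```

For example, pack_fp4([1.0, -6.0]) yields the single byte 0xF2 under these assumptions.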

Notes

  • Requires SM100+ (Blackwell) for FP4 quantization PTX intrinsics.

  • For block_size=16 (NVFP4): uses E4M3 scale factors (max value 448.0).

  • For block_size=32 (MXFP4): uses UE8M0 scale factors (power-of-2 scales).

  • FP4 E2M1 format has a max representable value of 6.0.
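The notes give the format limits but not the exact rule used to choose each block's scale. The sketch below assumes one common choice: for NVFP4, scale = amax / 6 rounded to E4M3 (saturating at 448); for MXFP4, the smallest power of two s with amax / s <= 6. These rules are assumptions for illustration and may differ from what the kernel does; the helper names are hypothetical.

```python
import math
from typing import List

FP4_MAX = 6.0     # largest E2M1 magnitude, per the notes
E4M3_MAX = 448.0  # largest E4M3 scale value, per the notes


def round_to_e4m3(x: float) -> float:
    """Round a non-negative float to the nearest E4M3 value (ties to even), saturating at 448."""
    if x >= E4M3_MAX:
        return E4M3_MAX
    if x < 2.0 ** -6:
        # Subnormal range: representable values are multiples of 2**-9.
        return round(x / 2.0 ** -9) * 2.0 ** -9
    _, e = math.frexp(x)           # x = m * 2**e with 0.5 <= m < 1
    step = math.ldexp(1.0, e - 4)  # 3 mantissa bits -> spacing 2**(floor(log2(x)) - 3)
    return min(round(x / step) * step, E4M3_MAX)


def nvfp4_block_scale(block: List[float]) -> float:
    """Assumed NVFP4 rule: amax / 6, rounded to E4M3."""
    amax = max(abs(v) for v in block)
    return round_to_e4m3(amax / FP4_MAX)


def mxfp4_block_scale(block: List[float]) -> float:
    """Assumed MXFP4 rule: smallest power of two s with amax / s <= 6.

    The UE8M0 exponent range is not clamped in this sketch.
    """
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0  # any scale represents an all-zero block exactly
    return 2.0 ** math.ceil(math.log2(amax / FP4_MAX))
```

Rounding the MXFP4 exponent up keeps every scaled element within the E2M1 range at the cost of some precision. Rounding it down would instead clip the largest values in the block.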