flashinfer.quantization.kernels.nvfp4_quantize.nvfp4_quantize_per_token_cute_dsl

flashinfer.quantization.kernels.nvfp4_quantize.nvfp4_quantize_per_token_cute_dsl(input: Tensor, global_scale_inv: Tensor, sf_layout: int = 0, enable_pdl: bool | None = None) → Tuple[Tensor, Tensor, Tensor]

Per-token NVFP4 activation quantization using the CuTe-DSL kernel.

Unlike nvfp4_quantize_cute_dsl(), which applies a single global scale, this variant computes one quantization scale per row (token) of the activation. Each row is scaled independently so that its largest magnitude maps to the NVFP4 dynamic range, and the resulting per-token scale is returned alongside the packed FP4 output and the E4M3 block scale factors.
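The per-token scaling idea can be sketched in plain Python. This is a hypothetical reference, not the kernel's code: it assumes the usual two-level NVFP4 scheme, in which the combined dynamic range is the E2M1 maximum (6.0) times the E4M3 maximum (448.0), and it leaves out global_scale_inv because this page does not spell out exactly how that factor is combined with the per-token scale.

```python
# Assumed NVFP4 constants (two-level scheme): E2M1 max x E4M3 max.
E2M1_MAX = 6.0     # largest E2M1 magnitude
E4M3_MAX = 448.0   # largest finite E4M3 magnitude
NVFP4_RANGE = E2M1_MAX * E4M3_MAX  # 2688.0


def per_token_scales(rows):
    """Hypothetical reference: one float scale per row (token).

    Each scale is chosen so that the row's largest magnitude, after
    division by the scale, lands exactly on NVFP4_RANGE.
    """
    scales = []
    for row in rows:
        amax = max((abs(x) for x in row), default=0.0)
        # All-zero rows get a neutral scale to avoid dividing by zero.
        scales.append(amax / NVFP4_RANGE if amax > 0.0 else 1.0)
    return scales


def apply_per_token_scale(rows, scales):
    # After this step every row's largest magnitude is at most NVFP4_RANGE.
    return [[x / s for x in row] for row, s in zip(rows, scales)]
```

Because each row gets its own divisor, a single outlier token no longer compresses the resolution available to every other token, which is the motivation for a per-token variant over a single global scale.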

  • E4M3 block scale factors (FP8), sf_vec_size = 16

  • E2M1 output format (4-bit, 2 values per byte)

  • Supports 128x4, 8x4, and linear scale-factor layouts
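To make the block format concrete, the sketch below quantizes one 16-element block to E2M1 codes and packs two codes per byte. It is illustrative only. It keeps the block scale as a Python float, whereas the kernel stores it as E4M3. It also assumes the even element goes in the low nibble, and it resolves rounding ties toward the smaller magnitude instead of using hardware round-to-nearest-even. The helper names are hypothetical.

```python
# E2M1 magnitudes; the list index is the 3-bit magnitude code (exp:2, mant:1).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
SF_VEC_SIZE = 16  # elements sharing one block scale factor


def encode_e2m1(x):
    """Return a 4-bit E2M1 code: sign in bit 3, magnitude code in bits 0-2."""
    sign = 0x8 if x < 0 else 0x0
    mag = abs(x)
    # Nearest representable magnitude; min() keeps the first (smaller) code on ties.
    code = min(range(len(E2M1_MAGNITUDES)),
               key=lambda i: abs(E2M1_MAGNITUDES[i] - mag))
    return sign | code


def quantize_block(block):
    """Quantize one 16-element block; returns (packed_bytes, block_scale)."""
    assert len(block) == SF_VEC_SIZE
    amax = max(abs(x) for x in block)
    # Map the block's largest magnitude onto the E2M1 maximum of 6.0.
    scale = amax / 6.0 if amax > 0.0 else 1.0
    codes = [encode_e2m1(x / scale) for x in block]
    # Assumed packing: even element in the low nibble, odd element in the high nibble.
    packed = bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, SF_VEC_SIZE, 2))
    return packed, scale
```

Sixteen 4-bit codes pack into 8 bytes, which is why the packed output has K/2 bytes per row.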

The kernel is compiled once per (K, dtype, sf_layout, pdl) tuple and handles varying M (number of tokens) at runtime without recompilation.
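The caching behaviour described above can be pictured as a memoized compile step keyed only on the static parameters. This is a hypothetical illustration using functools.lru_cache. get_compiled_kernel and launch are not FlashInfer functions, and the real implementation may cache differently.

```python
from functools import lru_cache

COMPILE_CALLS = []  # records each (hypothetical) compilation for inspection


@lru_cache(maxsize=None)
def get_compiled_kernel(k, dtype, sf_layout, enable_pdl):
    """Stand-in for a JIT compile; runs once per static (K, dtype, layout, pdl) key."""
    COMPILE_CALLS.append((k, dtype, sf_layout, enable_pdl))
    return f"kernel[K={k},{dtype},layout={sf_layout},pdl={enable_pdl}]"


def launch(m, k, dtype, sf_layout, enable_pdl):
    # M is deliberately excluded from the cache key: it is a runtime launch argument.
    kernel = get_compiled_kernel(k, dtype, sf_layout, enable_pdl)
    return kernel, m
```

Keeping M out of the key means a serving loop with a fluctuating batch size triggers compilation only once per hidden size and dtype.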

Parameters:
  • input (torch.Tensor) – 2-D activation tensor of shape [M, K] with dtype fp16/bf16. K must be divisible by NVFP4_SF_VEC_SIZE (16).

  • global_scale_inv (torch.Tensor) – Scalar tensor (float32) holding the inverse global scale applied on top of the per-token scale. A Python float is also accepted and wrapped into a tensor internally.

  • sf_layout (int) – Scale-factor layout (0=128x4, 1=8x4, 2=linear).

  • enable_pdl (bool, optional) – Whether to enable Programmatic Dependent Launch. Auto-detected from device capability (SM >= 9.0) when None; pass False to force it off.
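The documented parameter rules can be collected into a small pre-launch check. This is a hypothetical helper, not part of FlashInfer. It takes the input shape and the device's (major, minor) compute capability as plain Python values, and it mirrors only the constraints stated above.

```python
NVFP4_SF_VEC_SIZE = 16
SF_LAYOUTS = {0: "128x4", 1: "8x4", 2: "linear"}


def check_args(shape, sf_layout, enable_pdl, capability):
    """Hypothetical pre-launch checks; returns the resolved enable_pdl flag.

    shape: (M, K) of the activation; capability: (major, minor) SM version.
    """
    if len(shape) != 2:
        raise ValueError("input must be 2-D [M, K]")
    _, k = shape
    if k % NVFP4_SF_VEC_SIZE != 0:
        raise ValueError(f"K={k} must be divisible by {NVFP4_SF_VEC_SIZE}")
    if sf_layout not in SF_LAYOUTS:
        raise ValueError(f"sf_layout must be one of {sorted(SF_LAYOUTS)}")
    if enable_pdl is None:
        # Auto-detect: PDL is enabled on SM 9.0 and newer.
        return capability >= (9, 0)
    return bool(enable_pdl)
```

Tuple comparison handles the version check directly, so (8, 9) resolves to False and (10, 0) resolves to True.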

Returns:

(fp4_output, scale_output, per_token_scale) where:

  • fp4_output is the packed quantized tensor of shape [M, K/2] with dtype uint8 (two E2M1 values per byte).

  • scale_output holds the E4M3 block scale factors (uint8) reshaped to [padded_rows, padded_sf_cols]. Before padding there are K/16 scale columns (one per 16-element block). The padding depends on sf_layout: linear keeps M rows, while 128x4 / 8x4 pad rows and columns up to a multiple of the layout tile.

  • per_token_scale is the per-row quantization scale of shape [M] with dtype float32.
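The expected output shapes can be estimated with a small calculator. This is a hypothetical helper. Its padding rules are an assumption inferred from the layout names: rows padded to a multiple of 128 (128x4) or 8 (8x4), scale columns padded to a multiple of 4, and no padding for the linear layout. Verify the rules against the actual returned tensors before relying on them.

```python
def round_up(x, multiple):
    return (x + multiple - 1) // multiple * multiple


def expected_output_shapes(m, k, sf_layout):
    """Hypothetical shapes of (fp4_output, scale_output, per_token_scale)."""
    sf_cols = k // 16  # one E4M3 scale per 16-element block
    if sf_layout == 0:    # 128x4: assumed rows -> multiple of 128, cols -> multiple of 4
        sf_shape = (round_up(m, 128), round_up(sf_cols, 4))
    elif sf_layout == 1:  # 8x4: assumed rows -> multiple of 8, cols -> multiple of 4
        sf_shape = (round_up(m, 8), round_up(sf_cols, 4))
    elif sf_layout == 2:  # linear: M rows, no padding assumed
        sf_shape = (m, sf_cols)
    else:
        raise ValueError("sf_layout must be 0, 1, or 2")
    # fp4_output packs two E2M1 values per byte, hence K // 2 columns.
    return (m, k // 2), sf_shape, (m,)
```

For example, with M = 3 and K = 64 the linear layout needs only 3 × 4 scale bytes, while 128x4 allocates 128 × 4. The swizzled layouts trade this padding for tile-aligned scale reads downstream.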

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]