flashinfer.quantization.scaled_fp4_grouped_quantize

flashinfer.quantization.scaled_fp4_grouped_quantize(a, mask, a_global_sf)

Quantize a batched input tensor to NVFP4 with a per-row mask.

Parameters:
  • a (torch.Tensor) – Input tensor of shape [B, M, K] with dtype fp16/bf16.

  • mask (torch.Tensor) – Mask tensor applied before quantization.

  • a_global_sf (torch.Tensor) – Global scale factor of shape [1] with dtype float32.

Returns:

A tuple (x_q, sf). x_q has logical shape [M, K/2, B] and dtype FLOAT4_E2M1X2; the implementation permutes the physical [B, M, K/2] layout so that the batch dimension is last, as required by FlashInfer’s masked grouped GEMM. sf is the 6D swizzled scale-factor tensor of logical shape [32, 4, padded_M // 128, 4, padded_K // 4, B], viewed as float8_e4m3fn. Here padded_M is M rounded up to a multiple of 128, and padded_K is K // sf_vec_size (with sf_vec_size = 16) rounded up to a multiple of 4, so each batch entry holds padded_M × padded_K scale factors.

Return type:

Tuple[torch.Tensor, torch.Tensor]
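
Example:

The shape arithmetic in the Returns description can be checked without a GPU. The sketch below is not part of FlashInfer: expected_grouped_fp4_shapes and ceil_to_multiple are hypothetical helpers written for illustration. It assumes that K is a multiple of sf_vec_size and that the six scale-factor dimensions multiply out to padded_M * padded_K * B, which makes the fifth dimension padded_K // 4.

```python
from math import prod
from typing import Tuple


def ceil_to_multiple(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple of `multiple`."""
    return (value + multiple - 1) // multiple * multiple


def expected_grouped_fp4_shapes(
    B: int, M: int, K: int, sf_vec_size: int = 16
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical helper: logical shapes of (x_q, sf) for an input [B, M, K]."""
    # Assumption: K must divide evenly into scale-factor blocks.
    if K % sf_vec_size != 0:
        raise ValueError(f"K must be a multiple of {sf_vec_size}, got {K}")

    padded_m = ceil_to_multiple(M, 128)               # rows padded to 128
    padded_k = ceil_to_multiple(K // sf_vec_size, 4)  # scale columns padded to 4

    # Two FP4 values are packed per byte, and the batch dimension is last.
    x_q_shape = (M, K // 2, B)
    # Swizzled layout: 32 * 4 = 128 rows per tile, 4 scale columns per tile.
    sf_shape = (32, 4, padded_m // 128, 4, padded_k // 4, B)
    return x_q_shape, sf_shape


x_q_shape, sf_shape = expected_grouped_fp4_shapes(B=4, M=200, K=4096)
print(x_q_shape)  # (200, 2048, 4)
print(sf_shape)   # (32, 4, 2, 4, 64, 4)

# Sanity check: one scale factor per 16-element block of the padded input.
assert prod(sf_shape) == 256 * 256 * 4
```

Comparing these tuples against tuple(x_q.shape) and tuple(sf.shape) from a real call is a quick way to confirm that the padding rules are being read correctly for a given problem size.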