flashinfer.quantization.scaled_fp4_grouped_quantize
- flashinfer.quantization.scaled_fp4_grouped_quantize(a, mask, a_global_sf)
Quantize a batched input tensor to NVFP4 with a per-row mask.
- Parameters:
a (torch.Tensor) – Input tensor of shape [B, M, K] with dtype fp16/bf16.
mask (torch.Tensor) – Mask tensor applied before quantization.
a_global_sf (torch.Tensor) – Global scale factor of shape [1] with dtype float32.
- Returns:
(x_q, sf), where x_q has logical shape [M, K/2, B] with dtype FLOAT4_E2M1X2 (the implementation permutes the [B, M, K/2] physical layout so the batch dim is last, as required by FlashInfer’s masked grouped GEMM), and sf is the 6D swizzled scale-factor tensor of logical shape [32, 4, padded_M // 128, 4, padded_K // 64, B] viewed as float8_e4m3fn. padded_M rounds M up to a multiple of 128, and padded_K rounds K // sf_vec_size (with sf_vec_size = 16) up to a multiple of 4.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
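The padding rules above can be checked with plain integer arithmetic before allocating any tensors. The following is a minimal sketch, not part of FlashInfer: the helper names round_up and nvfp4_grouped_padding are hypothetical, and the requirement that K be a multiple of 16 is an assumption made here so that K // 2 and K // sf_vec_size are exact. It reproduces only padded_M, padded_K, and the logical shape of x_q as stated in the description; it does not call the library or model the swizzled layout of sf.

```python
SF_VEC_SIZE = 16  # sf_vec_size as stated in the Returns description


def round_up(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple (hypothetical helper)."""
    return (value + multiple - 1) // multiple * multiple


def nvfp4_grouped_padding(B: int, M: int, K: int) -> dict:
    """Compute the padded sizes and the logical x_q shape for an input [B, M, K].

    Hypothetical helper for illustration only; it mirrors the documented
    rounding rules and does not invoke flashinfer.
    """
    # Assumption: K is a multiple of sf_vec_size, so every scale factor
    # covers a full group of 16 elements and K // 2 is exact.
    if K % SF_VEC_SIZE != 0:
        raise ValueError("this sketch assumes K is a multiple of 16")

    padded_m = round_up(M, 128)              # M rounded up to a multiple of 128
    padded_k = round_up(K // SF_VEC_SIZE, 4)  # K // 16 rounded up to a multiple of 4

    return {
        "padded_M": padded_m,
        "padded_K": padded_k,
        # Two FP4 values are packed per element; the batch dim is last.
        "x_q_logical_shape": (M, K // 2, B),
    }


if __name__ == "__main__":
    print(nvfp4_grouped_padding(B=4, M=100, K=256))
```

For an input of shape [4, 100, 256], the sketch gives padded_M = 128, padded_K = 16, and an x_q logical shape of [100, 128, 4].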