flashinfer.quantization.nvfp4_batched_quantize
- flashinfer.quantization.nvfp4_batched_quantize(a, a_global_sf, sf_vec_size=16)
Quantize batched input tensor to NVFP4 format.
- Parameters:
a (torch.Tensor) – Input tensor of shape [B, M, K] with dtype fp16/bf16.
a_global_sf (torch.Tensor) – Global scale factor of shape [1] with dtype float32.
sf_vec_size (int) – Scale-factor vector size. Defaults to 16.
- Returns:
(x_q, sf) where x_q has shape [B, M, K/2] with dtype FLOAT4_E2M1X2 and sf is the per-batch swizzled scale-factor tensor of shape [B, ceil(M / 128) * 128 * ceil(K / sf_vec_size / 4) * 4] (M is padded to a multiple of 128 and K / sf_vec_size is rounded up to a multiple of 4).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
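
The shape arithmetic in the Returns entry can be sketched in plain Python. The helper name `expected_nvfp4_shapes` is hypothetical and not part of FlashInfer; it only reproduces the documented formulas and does not call the library. The sketch also assumes that K is even and divisible by `sf_vec_size`, which the entry above does not state explicitly.

```python
def expected_nvfp4_shapes(B, M, K, sf_vec_size=16):
    """Hypothetical helper: output shapes as described in the Returns entry.

    Assumption (not stated in the reference): K is even and a multiple
    of sf_vec_size.
    """
    if K % 2 != 0 or K % sf_vec_size != 0:
        raise ValueError("this sketch assumes K is even and divisible by sf_vec_size")

    # Two FP4 values are packed per element, so the last dimension halves.
    x_q_shape = (B, M, K // 2)

    # M is padded up to a multiple of 128.
    padded_m = -(-M // 128) * 128

    # The number of scale factors per row (K / sf_vec_size) is rounded up
    # to a multiple of 4.
    sf_per_row = K // sf_vec_size
    padded_sf_per_row = -(-sf_per_row // 4) * 4

    sf_shape = (B, padded_m * padded_sf_per_row)
    return x_q_shape, sf_shape


print(expected_nvfp4_shapes(2, 100, 64))   # ((2, 100, 32), (2, 512))
print(expected_nvfp4_shapes(1, 130, 96))   # ((1, 130, 48), (1, 2048))

# The actual call, per the signature above, would look like the line below.
# It is left commented out here because it is not exercised in this sketch:
# x_q, sf = flashinfer.quantization.nvfp4_batched_quantize(a, a_global_sf)
```

In the first case M = 100 is padded to 128 and K / sf_vec_size = 4 is already a multiple of 4, giving 128 * 4 = 512 scale-factor entries per batch. In the second, M = 130 is padded to 256 and K / sf_vec_size = 6 is rounded up to 8, giving 2048.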