flashinfer.quantization.nvfp4_batched_quantize

flashinfer.quantization.nvfp4_batched_quantize(a, a_global_sf, sf_vec_size=16)

Quantize a batched input tensor to NVFP4 format.

Parameters:
  • a (torch.Tensor) – Input tensor of shape [B, M, K] with dtype fp16/bf16.

  • a_global_sf (torch.Tensor) – Global scale factor of shape [1] with dtype float32.

  • sf_vec_size (int) – Scale-factor vector size. Defaults to 16.

Returns:

(x_q, sf), where x_q has shape [B, M, K/2] with dtype FLOAT4_E2M1X2 (two FP4 values packed per element) and sf is the per-batch swizzled scale-factor tensor of shape [B, ceil(M / 128) * 128 * ceil(K / sf_vec_size / 4) * 4]. In the sf shape, M is padded up to a multiple of 128 and K / sf_vec_size is rounded up to a multiple of 4.

Return type:

Tuple[torch.Tensor, torch.Tensor]
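
The shape arithmetic in the Returns description can be checked with a small pure-Python sketch. The helper name nvfp4_batched_output_shapes is hypothetical and not part of FlashInfer; it only mirrors the documented formulas and assumes K is even so that K/2 is an integer.

```python
def ceil_div(n: int, d: int) -> int:
    """Integer ceiling division for positive integers."""
    return -(-n // d)


def nvfp4_batched_output_shapes(B: int, M: int, K: int, sf_vec_size: int = 16):
    """Hypothetical helper (not a FlashInfer API): expected (x_q, sf) shapes.

    Mirrors the documented formulas only; assumes K is even.
    """
    # Two FP4 values are packed per element, so the last dim halves.
    x_q_shape = (B, M, K // 2)

    # Rows: M padded up to a multiple of 128.
    padded_rows = ceil_div(M, 128) * 128
    # Columns: one scale factor per sf_vec_size elements along K,
    # rounded up to a multiple of 4.
    sf_cols = ceil_div(K, sf_vec_size)
    padded_cols = ceil_div(sf_cols, 4) * 4

    # Per batch, the swizzled scale factors are stored flattened.
    sf_shape = (B, padded_rows * padded_cols)
    return x_q_shape, sf_shape


x_q_shape, sf_shape = nvfp4_batched_output_shapes(B=4, M=130, K=80)
print(x_q_shape, sf_shape)
```

For B=4, M=130, K=80, M pads to 256 rows and K / 16 = 5 scale factors per row round up to 8, so each batch entry holds 256 * 8 = 2048 scale factors.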