flashinfer.fp4_quantization.nvfp4_batched_quantize
- flashinfer.fp4_quantization.nvfp4_batched_quantize(a, a_global_sf, sf_vec_size=16, mask=None)
Quantize a batched input tensor to NVFP4 format (packed FP4 E2M1 values plus per-block scale factors).
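As a rough illustration of what NVFP4 quantization computes, the pure-Python sketch below quantizes one row block by block. It is not FlashInfer's kernel. It assumes a common convention for the global scale (the tensor amax is mapped onto 448 * 6, the FP8 E4M3 and FP4 E2M1 maxima), passes the global scale as a plain float rather than a [1]-shaped tensor, skips the FP8 E4M3 rounding of the block scales, and rounds each scaled element to the nearest E2M1 value. The names `global_scale`, `round_to_e2m1`, and `quantize_row` are hypothetical.

```python
# Hypothetical pure-Python reference of NVFP4 block quantization (not FlashInfer's kernel).
# Assumptions: block scales are kept as Python floats (no E4M3 rounding),
# and each element is rounded to the nearest E2M1 value.
E2M1_MAX = 6.0    # largest magnitude representable in FP4 E2M1
E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def global_scale(values):
    """Assumed convention: map the tensor amax onto E4M3_MAX * E2M1_MAX."""
    amax = max(abs(v) for v in values)
    return (E4M3_MAX * E2M1_MAX) / amax if amax > 0 else 1.0


def round_to_e2m1(x):
    """Round x to the nearest E2M1 value, keeping the sign.

    Ties go to the smaller magnitude because min() keeps the first minimum.
    """
    mag = min(E2M1_GRID, key=lambda g: abs(g - abs(x)))
    return -mag if x < 0 else mag


def quantize_row(row, g_sf, sf_vec_size=16):
    """Quantize one row block by block; returns (fp4_values, block_scales)."""
    if len(row) % sf_vec_size != 0:
        raise ValueError("row length must be a multiple of sf_vec_size")
    q, scales = [], []
    for start in range(0, len(row), sf_vec_size):
        block = row[start:start + sf_vec_size]
        amax = max(abs(v) for v in block)
        # Per-block scale in the globally scaled domain (a kernel would store it as E4M3).
        sf = g_sf * amax / E2M1_MAX
        scales.append(sf)
        inv = g_sf / sf if sf > 0 else 0.0
        q.extend(round_to_e2m1(v * inv) for v in block)
    return q, scales


row = [float(i) for i in range(-8, 8)]  # one 16-element block
g_sf = global_scale(row)
q, scales = quantize_row(row, g_sf)
print(g_sf, scales)  # 336.0 [448.0]
print(q[:4])         # [-6.0, -6.0, -4.0, -4.0]
```

In this sketch an element is recovered approximately as `q * sf / g_sf`. Under the assumed global-scale convention, the block that holds the tensor-wide maximum receives a scale of exactly 448, the largest finite E4M3 value, so the block scales use the full FP8 range.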
- Parameters:
a (torch.Tensor) – Input tensor of shape [B, M, K] with dtype fp16 or bf16.
a_global_sf (torch.Tensor) – Global scale factor of shape [1] with dtype float32.
sf_vec_size (int, optional) – Scale factor vector size: the number of consecutive elements along K that share one scale factor. Defaults to 16.
mask (torch.Tensor, optional) – Mask tensor applied to the input before quantization. Defaults to None.
- Returns:
- A tuple containing:
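Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 (two 4-bit E2M1 values packed per byte, hence K/2)
Scale factors tensor with one scale per sf_vec_size elements along K; its exact shape depends on the scale-factor layout and sf_vec_size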
- Return type:
Tuple[torch.Tensor, torch.Tensor]
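The documented output shapes follow from simple arithmetic. The sketch below uses hypothetical helpers that are not part of FlashInfer: `packed_shape` computes the packed output shape, `scale_blocks_per_row` counts block scales per row before any layout-specific padding or swizzling, and `e2m1_code` and `pack_e2m1_pair` show how two 4-bit E2M1 codes fit in one byte. The nibble order (first element in the low nibble) is an assumption, not something this page documents.

```python
# Hypothetical helpers illustrating the documented output shapes; not part of FlashInfer.
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # magnitudes for codes 0..7


def packed_shape(B, M, K):
    """Shape of the quantized tensor: two FP4 values share one byte along K."""
    if K % 2:
        raise ValueError("K must be even to pack two FP4 values per byte")
    return (B, M, K // 2)


def scale_blocks_per_row(K, sf_vec_size=16):
    """Block scale factors per row, ignoring any layout-specific padding."""
    if K % sf_vec_size:
        raise ValueError("K must be a multiple of sf_vec_size")
    return K // sf_vec_size


def e2m1_code(value):
    """4-bit E2M1 code: sign in bit 3, magnitude index in bits 0-2."""
    mag = abs(value)
    if mag not in E2M1_GRID:
        raise ValueError(f"{value} is not representable in E2M1")
    return (8 if value < 0 else 0) | E2M1_GRID.index(mag)


def pack_e2m1_pair(first, second):
    """Pack two 4-bit codes into one byte (assumed: first code in the low nibble)."""
    if not (0 <= first < 16 and 0 <= second < 16):
        raise ValueError("E2M1 codes must fit in 4 bits")
    return (second << 4) | first


print(packed_shape(4, 128, 256))  # (4, 128, 128)
print(scale_blocks_per_row(256))  # 16
print(hex(pack_e2m1_pair(e2m1_code(0.5), e2m1_code(-6.0))))  # 0xf1
```

With the default sf_vec_size of 16, an input with K = 256 therefore yields 128 packed bytes and 16 block scales per row, before whatever padding the scale-factor layout adds.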