flashinfer.quantization.mxfp4_dequantize

flashinfer.quantization.mxfp4_dequantize(a_fp4, a_sf)

Dequantize MXFP4 packed weights back to float32.

Parameters:
  • a_fp4 (torch.Tensor) – Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2); each byte packs two 4-bit E2M1 values.

  • a_sf (torch.Tensor) – UE8M0 scale-factor tensor with dtype uint8. Its shape depends on the scale-factor layout and sf_vec_size; this entry point expects the swizzled buffer produced by mxfp4_quantize() with sf_vec_size = 32.

Returns:

Dequantized tensor of shape [M, K] with dtype float32.

Return type:

torch.Tensor
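
Example:

The decoding arithmetic can be illustrated with a minimal pure-Python sketch. It is not the FlashInfer kernel: reference_mxfp4_dequantize and its helpers are hypothetical names introduced here, the scale factors are assumed to be stored row-major and unswizzled (one byte per 32 elements), and the low nibble of each byte is assumed to hold the even-indexed element. The swizzled buffer that mxfp4_dequantize() expects would need to be reordered before it could be compared against this sketch.

```python
# E2M1 magnitudes indexed by the low 3 bits of a 4-bit code; bit 3 is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
SF_VEC_SIZE = 32  # number of elements sharing one UE8M0 scale in MXFP4


def decode_e2m1(code):
    """Decode one 4-bit E2M1 code (0..15) to a Python float."""
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def decode_ue8m0(byte):
    """Decode a UE8M0 scale byte as 2**(byte - 127); 0xFF encodes NaN."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - 127)


def reference_mxfp4_dequantize(packed_rows, scale_rows):
    """Hypothetical reference decoder (assumptions listed in the text above).

    packed_rows: M rows of K/2 ints in 0..255, two E2M1 codes per byte.
    scale_rows:  M rows of K/32 ints in 0..255, row-major and unswizzled.
    Returns M rows of K Python floats.
    """
    out = []
    for packed, scales in zip(packed_rows, scale_rows):
        row = []
        for j, byte in enumerate(packed):
            # Elements 2*j and 2*j + 1 fall in the same 32-element block.
            scale = decode_ue8m0(scales[(2 * j) // SF_VEC_SIZE])
            row.append(decode_e2m1(byte & 0x0F) * scale)  # assumed: low nibble first
            row.append(decode_e2m1(byte >> 4) * scale)
        out.append(row)
    return out


# One row with K = 32: the first byte packs codes 1 (0.5) and 2 (1.0);
# scale byte 128 decodes to 2**1 = 2.0.
values = reference_mxfp4_dequantize([[0x21] + [0x00] * 15], [[128]])
print(values[0][:2])
```

A sketch like this is mainly useful for checking the number format by hand on small inputs; for real tensors, the library function operates on the buffers returned by mxfp4_quantize() directly.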