flashinfer.quantization.mxfp8_grouped_quantize
- flashinfer.quantization.mxfp8_grouped_quantize(a: Tensor, mask: Tensor) → Tuple[Tensor, Tensor]
Quantize grouped inputs to MXFP8 with UE8M0 block scales.
- Parameters:
a (torch.Tensor) – Input tensor of shape [B, M, K] with dtype float16 or bfloat16.
mask (torch.Tensor) – Int32 CUDA tensor of shape [B]. Each value gives the number of valid rows to quantize for the corresponding group and must satisfy 0 <= mask[i] <= M. This precondition is the caller's responsibility: it is not validated at runtime, because reading the device-side mask values would force a host synchronization and break CUDA-graph capture. Out-of-range values are undefined behavior. The kernel writes scale factors with bounds checking disabled, so mask[i] > M corrupts neighboring groups or writes out of bounds.
- Returns:
(x_q, sf) where x_q has logical shape [M, padded_K, B] with dtype float8_e4m3fn and sf has logical shape [32, 4, padded_M // 128, 4, padded_K // 128, B] with dtype uint8. padded_K rounds K up to a multiple of 128. The physical layouts are grouped by B and then permuted to match FlashInfer masked grouped GEMM conventions.
Only the first mask[i] rows of group i are written; rows >= mask[i] (and their scale factors) are unspecified. The consumer must use the same mask and read only the valid rows.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
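
Because the mask precondition is not checked at runtime, one option is to validate the per-group row counts on the host before they are uploaded to the device, where no synchronization is needed. The sketch below is illustrative only and is not part of FlashInfer: the helper names `validate_group_row_counts` and `expected_output_shapes` are hypothetical, and it assumes that padded_M rounds M up to a multiple of 128 in the same way as padded_K, which the description above does not state explicitly.

```python
from typing import List, Sequence, Tuple

SF_TILE = 128  # padding multiple taken from the documented shapes


def round_up(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple."""
    return (value + multiple - 1) // multiple * multiple


def validate_group_row_counts(counts: Sequence[int], m: int) -> List[int]:
    """Check 0 <= counts[i] <= m on the host, before any device upload."""
    for i, count in enumerate(counts):
        if not 0 <= count <= m:
            raise ValueError(f"mask[{i}] = {count} is outside [0, {m}]")
    return list(counts)


def expected_output_shapes(b: int, m: int, k: int) -> Tuple[tuple, tuple]:
    """Logical shapes of (x_q, sf) for an input of shape [b, m, k]."""
    padded_k = round_up(k, SF_TILE)
    # Assumption: padded_M is M rounded up to a multiple of 128.
    padded_m = round_up(m, SF_TILE)
    x_q_shape = (m, padded_k, b)
    sf_shape = (32, 4, padded_m // SF_TILE, 4, padded_k // SF_TILE, b)
    return x_q_shape, sf_shape


counts = validate_group_row_counts([0, 100, 37], m=100)
x_q_shape, sf_shape = expected_output_shapes(b=3, m=100, k=200)
print(x_q_shape)  # (100, 256, 3)
print(sf_shape)   # (32, 4, 1, 4, 2, 3)
```

The validated counts could then be converted to the int32 CUDA tensor expected as mask, for example with `torch.tensor(counts, dtype=torch.int32, device="cuda")`, outside any CUDA-graph capture region. If the counts exist only on the device, a host-side check like this is not possible without a synchronization, which is why the function leaves the precondition to the caller.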