flashinfer.quantization.mxfp8_grouped_quantize

flashinfer.quantization.mxfp8_grouped_quantize(a: Tensor, mask: Tensor) → Tuple[Tensor, Tensor]

Quantize grouped inputs to MXFP8 with UE8M0 block scales.

Parameters:
  • a (torch.Tensor) – Input tensor of shape [B, M, K] with dtype float16 or bfloat16.

  • mask (torch.Tensor) – Int32 CUDA tensor of shape [B]. mask[i] is the number of valid rows to quantize in group i and must satisfy 0 <= mask[i] <= M. The caller is responsible for this precondition: it is not checked at runtime, because reading the device-side mask values would force a host synchronization and break CUDA-graph capture. Out-of-range values cause undefined behavior. In particular, the kernel writes scale factors without bounds checking, so mask[i] > M can corrupt neighboring groups or write out of bounds.
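Because the kernel trusts the mask, a debug-time check on the host can catch bad values before they corrupt memory. The sketch below is a hypothetical helper (check_grouped_mask is not part of FlashInfer). It assumes you already have a host copy of the mask, for example from mask.tolist(), taken outside any CUDA-graph capture, since making that copy forces a synchronization.

```python
from typing import Sequence


def check_grouped_mask(mask_values: Sequence[int], num_groups: int, max_rows: int) -> None:
    """Hypothetical host-side check of the precondition 0 <= mask[i] <= M.

    mask_values is assumed to be a host copy of the device mask
    (e.g. mask.tolist()), taken outside CUDA-graph capture.
    """
    # One entry per group: the mask must have shape [B].
    if len(mask_values) != num_groups:
        raise ValueError(f"mask must have B={num_groups} entries, got {len(mask_values)}")
    # Every group must request between 0 and M rows, inclusive.
    for i, rows in enumerate(mask_values):
        if not 0 <= rows <= max_rows:
            raise ValueError(f"mask[{i}]={rows} is outside [0, {max_rows}]")
```

In a debug path you might call check_grouped_mask(mask.tolist(), a.shape[0], a.shape[1]) before mxfp8_grouped_quantize, where a.shape[0] is B and a.shape[1] is M, and skip the check in captured or latency-sensitive paths.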

Returns:

(x_q, sf), where x_q has logical shape [M, padded_K, B] with dtype float8_e4m3fn, and sf has logical shape [32, 4, padded_M // 128, 4, padded_K // 128, B] with dtype uint8 (each byte is one UE8M0 scale). padded_M and padded_K round M and K up to a multiple of 128. The physical layouts are grouped by B and then permuted to match FlashInfer masked grouped GEMM conventions.
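The shape arithmetic above can be sketched in plain Python. This is an illustrative helper, not a FlashInfer function. It assumes padded_M rounds M up to a multiple of 128 in the same way padded_K does, and uses the MX block size of 32 elements per scale to cross-check that sf holds one scale per 32-element block of the padded input.

```python
from typing import Tuple

MX_BLOCK = 32  # elements sharing one UE8M0 scale in the MX formats
TILE = 128     # padding granularity for M and K in the layouts above


def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple of `multiple`."""
    return (x + multiple - 1) // multiple * multiple


def mxfp8_grouped_output_shapes(B: int, M: int, K: int) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical helper: logical shapes of (x_q, sf) for an input of shape [B, M, K]."""
    padded_M = round_up(M, TILE)  # assumption: M is padded like K
    padded_K = round_up(K, TILE)
    x_q_shape = (M, padded_K, B)
    sf_shape = (32, 4, padded_M // TILE, 4, padded_K // TILE, B)
    return x_q_shape, sf_shape


x_q_shape, sf_shape = mxfp8_grouped_output_shapes(B=2, M=200, K=300)
print(x_q_shape)  # (200, 384, 2)
print(sf_shape)   # (32, 4, 2, 4, 3, 2)
```

For this example the product of sf_shape is 6144, which equals B * padded_M * padded_K / MX_BLOCK = 2 * 256 * 384 / 32, i.e. one scale byte per 32-element block.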

Only the first mask[i] rows of group i are written; rows >= mask[i] (and their scale factors) are unspecified. The consumer must use the same mask and read only the valid rows.
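To make the valid-row rule concrete, the pure-Python sketch below indexes a nested list laid out like x_q ([M][K][B]) and returns only the rows a consumer may read for one group. valid_rows_of_group is a hypothetical helper. With PyTorch tensors, the equivalent view would be x_q[: int(mask[i]), :, i], where int(mask[i]) synchronizes if mask lives on the GPU.

```python
from typing import List, Sequence


def valid_rows_of_group(
    x: List[List[List[int]]], mask: Sequence[int], group: int
) -> List[List[int]]:
    """Return the rows of one group that the quantizer actually wrote.

    x is a nested list with logical layout [M][K][B], mirroring x_q's
    [M, padded_K, B] shape. Only rows m < mask[group] are returned, since
    rows at or beyond mask[group] are unspecified.
    """
    num_valid = mask[group]
    # Slice the leading M dimension, then pick this group's column of each row.
    return [[row[k][group] for k in range(len(row))] for row in x[:num_valid]]


# Toy data: M=3 rows, K=2 columns, B=2 groups; value encodes (b, m, k).
x = [[[100 * b + 10 * m + k for b in range(2)] for k in range(2)] for m in range(3)]
print(valid_rows_of_group(x, mask=[2, 0], group=0))  # [[0, 1], [10, 11]]
print(valid_rows_of_group(x, mask=[2, 0], group=1))  # []
```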

Return type:

Tuple[torch.Tensor, torch.Tensor]