flashinfer.grouped_mm.moe_gemm_mxfp8_nt_groupwise
- flashinfer.grouped_mm.moe_gemm_mxfp8_nt_groupwise(a: Tensor, b: Tensor, a_scale: Tensor, b_scale: Tensor, m_indptr: Tensor, scale_granularity_mnk: Tuple[int, int, int] = (1, 1, 128), scale_major_mode: Literal['MN'] = 'MN', backend: Literal['cute'] = 'cute', out: Tensor | None = None, out_dtype: dtype | None = None) → Tensor
Perform grouped GEMM with MXFP8 inputs in zero-padding mode using groupwise UE8M0 scaling. Currently supported only on the NVIDIA RTX PRO 6000 Blackwell (SM120) architecture.
Zero-padding mode accepts a token-packed input a (no per-expert pre-padding along M), while the scale tensor a_scale carries 4-row per-expert padding. The group descriptor is the CSR-style cumulative sum m_indptr. This mode is optimized for decoding with small per-expert M (down to m_per_expert = 1), where DeepGEMM-style contiguous padding would waste memory and compute.
- Parameters:
  - a (torch.Tensor) – Row-major input tensor of shape (cum_m, k), data type torch.float8_e4m3fn. Token-packed across experts; cum_m is the cumulative sum of the segment lengths.
  - b (torch.Tensor) – Column-major input tensor of shape (num_experts, n, k), data type torch.float8_e4m3fn.
  - a_scale (torch.Tensor) – Int32-packed UE8M0 scale tensor for a (4 UE8M0 scales packed per int32), shape (m_padded, k_align), where m_padded = (cum_m + num_experts * 3) // 4 * 4 and k_align = (k + 4 * k_granularity - 1) // (4 * k_granularity). Data type torch.int32.
  - b_scale (torch.Tensor) – Int32-packed UE8M0 scale tensor for b in per-token layout, shape (num_experts, n, k_align). Data type torch.int32. See Notes for the per-token layout requirement.
  - m_indptr (torch.Tensor) – The indptr of the segment lengths, shape (num_experts + 1,), data type torch.int32, with m_indptr[0] = 0 and m_indptr[num_experts] = cum_m.
  - scale_granularity_mnk (Tuple[int, int, int]) – The granularity of the scale tensors, (m_granularity, n_granularity, k_granularity). Accepted values: (1, 1, 128) (DeepGEMM-style production, default) or (1, 1, 32) (OCP MXFP8). m_granularity and n_granularity must both be 1 (per-token scaling along M and N), and k_granularity must be 32 or 128. Anything else raises ValueError.
  - backend (Literal["cute"]) – Backend selector. Currently only "cute" is implemented.
  - out (Optional[torch.Tensor]) – The output tensor, shape (cum_m, n). If not specified, a new output tensor is created.
  - out_dtype (Optional[torch.dtype]) – The data type of the output tensor. Currently only torch.bfloat16 is supported.
- Returns:
out – The output tensor, shape (cum_m, n).
- Return type:
torch.Tensor
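The shape rules in the parameter list are plain integer arithmetic, so they can be checked on the host before any device tensors are allocated. The sketch below is pure Python and does not call flashinfer: check_granularity mirrors the documented ValueError rule, make_m_indptr builds the CSR cumsum, and a_scale_shape evaluates the documented m_padded and k_align formulas. padded_scale_offsets is a hypothetical helper that assumes each expert's a_scale rows start at a multiple of 4; the 4-row per-expert padding suggests this, but this page does not spell out the exact row layout.

```python
from itertools import accumulate


def check_granularity(scale_granularity_mnk):
    """Local mirror of the documented rule: only (1, 1, 32) or (1, 1, 128)."""
    m_g, n_g, k_g = scale_granularity_mnk
    if (m_g, n_g) != (1, 1) or k_g not in (32, 128):
        raise ValueError(f"unsupported scale_granularity_mnk: {scale_granularity_mnk}")
    return k_g


def make_m_indptr(seg_lens):
    """CSR-style cumsum: m_indptr[0] = 0 and m_indptr[-1] = cum_m."""
    return [0, *accumulate(seg_lens)]


def a_scale_shape(seg_lens, k, scale_granularity_mnk=(1, 1, 128)):
    """(m_padded, k_align) for a_scale, using the formulas from the parameter list."""
    k_g = check_granularity(scale_granularity_mnk)
    cum_m, num_experts = sum(seg_lens), len(seg_lens)
    m_padded = (cum_m + num_experts * 3) // 4 * 4
    # One int32 word holds 4 UE8M0 scales, each covering k_g elements of K.
    k_align = (k + 4 * k_g - 1) // (4 * k_g)
    return m_padded, k_align


def padded_scale_offsets(seg_lens):
    """HYPOTHETICAL helper: start row of each expert's block in a_scale,
    assuming every segment is rounded up to a multiple of 4 rows."""
    offsets = [0]
    for m in seg_lens:
        offsets.append(offsets[-1] + (m + 3) // 4 * 4)
    return offsets


# Example: 4 experts receiving 1, 3, 5 and 2 tokens during decoding.
seg_lens = [1, 3, 5, 2]
print(make_m_indptr(seg_lens))          # [0, 1, 4, 9, 11]
print(a_scale_shape(seg_lens, k=7168))  # (20, 14)
print(padded_scale_offsets(seg_lens))   # [0, 4, 8, 16, 20]
```

The documented m_padded is always at least the last value of padded_scale_offsets: rounding one segment up to a multiple of 4 adds at most 3 rows, so the rounded total is a multiple of 4 no larger than cum_m + 3 * num_experts. In real use, m_indptr would be materialized as a torch.int32 tensor, for example with torch.tensor(..., dtype=torch.int32, device="cuda").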
Notes
MXFP8 uses UE8M0 scales over K-axis blocks of size 32 (OCP spec) or 128 (DeepGEMM convention). Both a_scale and b_scale must be provided in per-token layout: one UE8M0 scale per row along M (for a) or N (for b), packed 4 scales per int32 along the K-axis blocks.

If a caller starts from a 2D (k_granularity, k_granularity) block-quantized b_scale, it must be broadcast to the per-token shape (num_experts, n, k_align) before invoking this function (one scale per N-row).
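The Notes describe two host-side steps for a 2D block-quantized b_scale: broadcast each (k_granularity, k_granularity) block scale to every N-row it covers, then encode the K-block scales as UE8M0 (value 2 ** (e - 127), with 255 reserved for NaN) and pack them 4 per int32. The pure-Python sketch below illustrates both on nested lists rather than tensors. Several details are assumptions not stated on this page and are labeled in the code: rounding non-power-of-two scales up when encoding, the byte order inside each int32 (first K-block in the lowest byte), and the value written to unused padding lanes. Check them against the library's own quantization utilities before relying on the output.

```python
import math


def ue8m0_encode(scale: float) -> int:
    """Encode a positive float scale as a UE8M0 byte (value = 2 ** (e - 127)).

    ASSUMPTION: non-power-of-two scales are rounded UP to the next power of two.
    """
    if not (scale > 0 and math.isfinite(scale)):
        raise ValueError(f"scale must be positive and finite, got {scale}")
    mant, exp = math.frexp(scale)  # scale = mant * 2**exp, 0.5 <= mant < 1
    log2_ceil = exp - 1 if mant == 0.5 else exp
    e = log2_ceil + 127
    if not 0 <= e <= 254:  # 255 is reserved for NaN
        raise ValueError(f"scale {scale} is outside the UE8M0 range")
    return e


def pack_ue8m0_row(exps, k_align, pad=127):
    """Pack per-K-block UE8M0 bytes 4 per int32.

    ASSUMPTIONS: block j goes to byte j % 4 of word j // 4 (lowest byte first),
    and unused lanes are filled with `pad` (127 encodes a scale of 1.0).
    """
    assert len(exps) <= 4 * k_align, "more K-blocks than k_align can hold"
    lanes = list(exps) + [pad] * (4 * k_align - len(exps))
    words = []
    for w in range(k_align):
        b0, b1, b2, b3 = lanes[4 * w : 4 * w + 4]
        word = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)
        words.append(word - (1 << 32) if word >= (1 << 31) else word)  # as signed int32
    return words


def broadcast_and_pack_b_scale(b_scale_2d, n, k, k_granularity):
    """Expand a (k_g, k_g) block-scaled b_scale to per-token layout and pack it.

    b_scale_2d[e][nb][kb] is the float scale of N-block nb, K-block kb of expert e.
    Returns nested lists of shape (num_experts, n, k_align).
    """
    k_align = (k + 4 * k_granularity - 1) // (4 * k_granularity)
    packed = []
    for expert in b_scale_2d:
        rows = []
        for row in range(n):
            block_row = expert[row // k_granularity]  # every N-row reuses its block's scales
            rows.append(pack_ue8m0_row([ue8m0_encode(s) for s in block_row], k_align))
        packed.append(rows)
    return packed


# One expert, n = 130 (two N-blocks of 128), k = 256 (two K-blocks of 128).
out = broadcast_and_pack_b_scale([[[1.0, 1.0], [2.0, 2.0]]], n=130, k=256, k_granularity=128)
print(len(out[0]), hex(out[0][0][0]), hex(out[0][129][0]))  # 130 0x7f7f7f7f 0x7f7f8080
```

In a real pipeline the same steps would be vectorized on the GPU (for example, with torch.repeat_interleave along N before packing); the loops here only make the row // k_granularity index mapping explicit.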