flashinfer.grouped_mm.moe_gemm_mxfp8_nt_groupwise

flashinfer.grouped_mm.moe_gemm_mxfp8_nt_groupwise(a: Tensor, b: Tensor, a_scale: Tensor, b_scale: Tensor, m_indptr: Tensor, scale_granularity_mnk: Tuple[int, int, int] = (1, 1, 128), scale_major_mode: Literal['MN'] = 'MN', backend: Literal['cute'] = 'cute', out: Tensor | None = None, out_dtype: dtype | None = None) → Tensor

Perform grouped GEMM with MXFP8 inputs in zero-padding mode using groupwise UE8M0 scaling. Currently supported only on the NVIDIA RTX PRO 6000 Blackwell (SM120) architecture.

Zero-padding mode accepts a token-packed input a (no per-expert pre-padding along M); only the scale tensor a_scale is padded, with each expert's rows rounded up to a 4-row boundary. The group descriptor is m_indptr, a CSR-style cumulative sum of the per-expert row counts. This mode is optimized for decoding with small per-expert M (down to m_per_expert = 1), where DeepGEMM-style contiguous padding would waste memory and compute.
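The CSR-style descriptor described above can be sketched in plain Python. This is a minimal illustration, not part of the FlashInfer API: the helper name build_m_indptr and the example token counts are hypothetical.

```python
from itertools import accumulate


def build_m_indptr(tokens_per_expert):
    """Return CSR-style offsets [0, m_0, m_0 + m_1, ...] for per-expert row counts.

    Hypothetical helper for illustration only.
    """
    if any(m < 0 for m in tokens_per_expert):
        raise ValueError("per-expert token counts must be non-negative")
    return [0] + list(accumulate(tokens_per_expert))


# Assumed routing result: 4 experts receiving 1, 2, 3 and 1 tokens.
tokens_per_expert = [1, 2, 3, 1]
m_indptr = build_m_indptr(tokens_per_expert)

# cum_m is the last offset, i.e. the number of rows in the token-packed `a`.
cum_m = m_indptr[-1]

# Expert i owns rows [m_indptr[i], m_indptr[i + 1]) of `a`.
expert_rows = [(m_indptr[i], m_indptr[i + 1]) for i in range(len(tokens_per_expert))]
```

Before the call, the resulting list would be converted to a torch.int32 tensor on the GPU, for example with torch.tensor(m_indptr, dtype=torch.int32, device="cuda").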

Parameters:
  • a (torch.Tensor) – Row-major input tensor of shape (cum_m, k) with data type torch.float8_e4m3fn. Token-packed across experts; cum_m is the total number of rows, i.e. the sum of all segment lengths.

  • b (torch.Tensor) – Column-major input tensor of shape (num_experts, n, k) with data type torch.float8_e4m3fn.

  • a_scale (torch.Tensor) – Int32-packed UE8M0 scale tensor for a (4 UE8M0 scales packed per int32), shape (m_padded, k_align) where m_padded = (cum_m + num_experts * 3) // 4 * 4 and k_align = (k + 4 * k_granularity - 1) // (4 * k_granularity). Data type is torch.int32.

  • b_scale (torch.Tensor) – Int32-packed UE8M0 scale tensor for b in per-token layout, shape (num_experts, n, k_align). Data type is torch.int32. See Notes for the per-token layout requirement.

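  • m_indptr (torch.Tensor) – Cumulative offsets of the per-expert segments, shape (num_experts + 1,), data type is torch.int32, with m_indptr[0] = 0 and m_indptr[num_experts] = cum_m. Rows m_indptr[i] up to (but not including) m_indptr[i + 1] of a belong to expert i.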

  • scale_granularity_mnk (Tuple[int, int, int]) – The granularity of the scale tensor, (m_granularity, n_granularity, k_granularity). Accepted values: (1, 1, 128) (DeepGEMM-style production, default) or (1, 1, 32) (OCP MXFP8). m_granularity and n_granularity must both be 1 (per-token scaling along M and N); k_granularity must be 32 or 128. Anything else raises ValueError.

  • scale_major_mode (Literal["MN"]) – Layout selector for the scale tensors. Per the signature, only "MN" (the default) is accepted.

  • backend (Literal["cute"]) – Backend selector. Currently only "cute" is implemented.

  • out (Optional[torch.Tensor]) – The output tensor, shape (cum_m, n). If not specified, an output tensor will be created.

  • out_dtype (Optional[torch.dtype]) – The data type of the output tensor. Currently only torch.bfloat16 is supported.
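
The shape formulas and the granularity constraint listed above can be checked with a small pure-Python sketch. The helper names validate_granularity and expected_shapes are hypothetical and only restate the documented rules; they are not part of FlashInfer, and the problem sizes in the example are assumed.

```python
def validate_granularity(scale_granularity_mnk):
    """Mirror the documented rule: m and n granularity must be 1, k must be 32 or 128."""
    m_g, n_g, k_g = scale_granularity_mnk
    if m_g != 1 or n_g != 1 or k_g not in (32, 128):
        raise ValueError(f"unsupported scale_granularity_mnk: {scale_granularity_mnk}")
    return k_g


def expected_shapes(cum_m, num_experts, n, k, scale_granularity_mnk=(1, 1, 128)):
    """Return the tensor shapes stated in the parameter list (hypothetical helper)."""
    k_granularity = validate_granularity(scale_granularity_mnk)
    # Formula from the a_scale entry: room for up to 3 padding rows per expert,
    # rounded down to a multiple of 4.
    m_padded = (cum_m + num_experts * 3) // 4 * 4
    # Four UE8M0 scales are packed per int32, so one int32 covers 4 * k_granularity
    # elements along K.
    k_align = (k + 4 * k_granularity - 1) // (4 * k_granularity)
    return {
        "a": (cum_m, k),
        "b": (num_experts, n, k),
        "a_scale": (m_padded, k_align),
        "b_scale": (num_experts, n, k_align),
        "m_indptr": (num_experts + 1,),
        "out": (cum_m, n),
    }


# Assumed example sizes: 7 packed tokens, 4 experts, n = 4096, k = 7168.
shapes = expected_shapes(cum_m=7, num_experts=4, n=4096, k=7168)
```

With these assumed sizes and the default granularity, a_scale has shape (16, 14); switching to (1, 1, 32) changes k_align from 14 to 56.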

Returns:

out – The output tensor, shape (cum_m, n).

Return type:

torch.Tensor

Notes

  • MXFP8 uses UE8M0 scales over K-axis blocks of size 32 (OCP spec) or 128 (DeepGEMM convention). Both a_scale and b_scale must be provided in per-token layout: one UE8M0 scale per row along M (for a) or N (for b), packed 4 scales per int32 along the K-axis blocks.

  • If a caller starts from a 2D (k_granularity, k_granularity) block-quantized b_scale, it must be broadcast to per-token shape (num_experts, n, k_align) before invoking this function (one scale per N-row).
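
The broadcast described in the last note can be sketched as follows. This is an illustration under stated assumptions, not FlashInfer code: it assumes the block-quantized b_scale is already packed along K and has shape (num_experts, ceil(n / k_granularity), k_align), and it uses nested Python lists in place of tensors. The helper name broadcast_block_scale_to_rows is hypothetical.

```python
def broadcast_block_scale_to_rows(block_scale, n, n_block):
    """Expand per-N-block scales to one scale row per N-row.

    block_scale: nested list, assumed shape (num_experts, ceil(n / n_block), k_align).
    Returns a nested list of shape (num_experts, n, k_align).
    """
    n_blocks_needed = (n + n_block - 1) // n_block
    per_row = []
    for expert in block_scale:
        if len(expert) != n_blocks_needed:
            raise ValueError("block_scale does not cover all n rows")
        # Row r of the output copies the scales of the N-block that contains r.
        per_row.append([list(expert[r // n_block]) for r in range(n)])
    return per_row


# Assumed toy input: 1 expert, n = 48 rows, N-blocks of 32 rows, k_align = 2.
# The integers stand in for int32-packed UE8M0 words.
block_scale = [[[10, 11], [20, 21]]]
row_scale = broadcast_block_scale_to_rows(block_scale, n=48, n_block=32)
```

With torch tensors, torch.repeat_interleave(block_scale, n_block, dim=1)[:, :n] should have the same effect under the same layout assumption.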