flashinfer.grouped_mm.grouped_mm_mxfp8

flashinfer.grouped_mm.grouped_mm_mxfp8(a: Tensor, b: Tensor, a_descale: Tensor, b_descale: Tensor, m_indptr: Tensor, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, *, backend: str = 'cudnn', tactic: int = -1) → Tensor

Grouped matrix multiplication with MXFP8 data types (cuDNN MOE backend).

Performs a grouped GEMM across experts, as used in Mixture-of-Experts layers. Mirrors flashinfer.mm_mxfp8() but for expert-partitioned inputs.

\[\text{out}[\text{start}:\text{end}] = a[\text{start}:\text{end}] \times b[e]^T \quad \text{for each expert } e\]

where start, end = m_indptr[e], m_indptr[e+1], and e ranges from 0 to batch_size - 1.
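The formula can be written as a plain PyTorch loop, which is handy for checking results on small inputs. This is a sketch, not the cuDNN implementation: grouped_mm_reference is a hypothetical helper, it works on already-dequantized float tensors (so it ignores MXFP8 quantization and the block scales), and it assumes the output has exactly cum_m rows.

```python
import torch


def grouped_mm_reference(
    a: torch.Tensor, b: torch.Tensor, m_indptr: torch.Tensor
) -> torch.Tensor:
    """Hypothetical float reference for the grouped GEMM formula.

    a: (cum_m, k) activations, b: (batch_size, n, k) expert weights,
    m_indptr: (batch_size + 1,) cumulative row offsets starting at 0.
    No FP8 quantization or block scaling is modeled.
    Assumes the output has exactly cum_m rows.
    """
    cum_m = a.shape[0]
    batch_size, n, _ = b.shape
    out = a.new_zeros((cum_m, n))
    for e in range(batch_size):
        start, end = int(m_indptr[e]), int(m_indptr[e + 1])
        # Rows start:end of `a` are routed to expert e and use b[e].
        out[start:end] = a[start:end] @ b[e].T
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.randn(5, 8)                                  # cum_m = 5, k = 8
    b = torch.randn(2, 3, 8)                               # batch_size = 2, n = 3
    m_indptr = torch.tensor([0, 2, 5], dtype=torch.int32)  # expert 0: rows 0-1, expert 1: rows 2-4
    out = grouped_mm_reference(a, b, m_indptr)
    print(out.shape)  # torch.Size([5, 3])
```

An expert that receives no tokens (m_indptr[e] == m_indptr[e+1]) simply contributes an empty slice, so the loop needs no special case for it.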

Parameters:
  • a (torch.Tensor) – Token activations, shape (cum_m, k), e4m3 or e5m2.

  • b (torch.Tensor) – Expert weights, shape (batch_size, n, k), e4m3 or e5m2.

  • a_descale (torch.Tensor) – Block scale tensor for A. Supported layout: 2D, swizzled 128x4, shape (cum_m, k // 32), dtype uint8.

  • b_descale (torch.Tensor) – Block scale tensor for B. Supported layout: 3D, swizzled 128x4, shape (batch_size, n, k // 32), dtype uint8.

  • m_indptr (torch.Tensor) – Cumulative token counts, shape (batch_size + 1,), int32.

  • out (Optional[torch.Tensor]) – Pre-allocated output tensor of shape (m_out, n). If None, a new output tensor is allocated.

  • out_dtype (torch.dtype) – Output data type: torch.bfloat16 (default), torch.float16, or torch.float32.

  • backend (str) – Backend selector. Currently only "cudnn" is supported.

  • tactic (int) – cuDNN execution-plan index. -1 (default) uses the heuristic-best plan; non-negative values select a specific plan.
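The shape and dtype requirements above can be exercised with a small PyTorch sketch. build_m_indptr, expected_descale_shapes, and dequantize_mxfp8 are hypothetical helpers, not FlashInfer APIs. The dequantization assumes the OCP MX convention that each uint8 scale is an E8M0 exponent encoding 2**(s - 127), shared by 32 consecutive elements along k, and it reads scales in a plain row-major layout; it does not undo the swizzled 128x4 layout that grouped_mm_mxfp8 expects.

```python
import torch

MX_BLOCK = 32  # elements along k that share one scale (matches the k // 32 shapes above)


def build_m_indptr(tokens_per_expert: list[int]) -> torch.Tensor:
    """Hypothetical helper: per-expert token counts -> (batch_size + 1,) int32 offsets."""
    counts = torch.tensor(tokens_per_expert, dtype=torch.int32)
    indptr = torch.zeros(len(tokens_per_expert) + 1, dtype=torch.int32)
    # Leading 0, then running totals: expert e owns rows indptr[e]:indptr[e+1].
    indptr[1:] = torch.cumsum(counts, dim=0, dtype=torch.int32)
    return indptr


def expected_descale_shapes(cum_m: int, n: int, k: int, batch_size: int):
    """Hypothetical helper: shapes of a_descale and b_descale as documented above."""
    if k % MX_BLOCK != 0:
        raise ValueError(f"k={k} must be a multiple of {MX_BLOCK}")
    return (cum_m, k // MX_BLOCK), (batch_size, n, k // MX_BLOCK)


def dequantize_mxfp8(x: torch.Tensor, scale_u8: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: multiply each 32-element block by 2**(scale - 127).

    Assumes a row-major (unswizzled) scale layout; the E8M0 NaN code 255 is not handled.
    """
    scale = torch.exp2(scale_u8.to(torch.float32) - 127.0)
    # Broadcast each per-block scale across its 32 elements along the last dim.
    return x.to(torch.float32) * scale.repeat_interleave(MX_BLOCK, dim=-1)


if __name__ == "__main__":
    print(build_m_indptr([2, 0, 3]).tolist())       # [0, 2, 2, 5]
    print(expected_descale_shapes(5, 16, 64, 3))    # ((5, 2), (3, 16, 2))
    x = torch.ones(2, 64).to(torch.float8_e4m3fn)   # requires a PyTorch build with float8 dtypes
    s = torch.tensor([[127, 128], [126, 127]], dtype=torch.uint8)
    print(dequantize_mxfp8(x, s)[:, ::32])          # block scales 1, 2 / 0.5, 1
```

Keeping the scale math in float32 mirrors the intent of a reference check; the real kernel applies the scales inside the cuDNN execution plan selected by tactic.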

Returns:

Output tensor (m_out, n).

Return type:

torch.Tensor