flashinfer.grouped_mm

Grouped matrix multiplication APIs for Mixture-of-Experts (MoE) layers. Each expert holds its own weight matrix, and the m_indptr tensor of cumulative token counts records which rows of the input are routed to which expert.
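To make the role of m_indptr concrete, the following is a minimal pure-Python reference sketch; it does not call FlashInfer. It assumes that rows are already sorted by expert, that m_indptr has one more entry than there are experts and starts at 0, and that each expert's weight is stored as a (k, n) matrix. These layout details, along with the helper names build_indptr, matmul, and grouped_mm_reference, are illustrative assumptions and may not match what the functions listed below expect.

```python
from itertools import accumulate


def build_indptr(tokens_per_expert):
    """Prefix sum with a leading 0: expert e owns rows indptr[e]:indptr[e + 1]."""
    return [0] + list(accumulate(tokens_per_expert))


def matmul(x, w):
    """Plain (m, k) @ (k, n) product on nested lists."""
    return [[sum(xi * wi for xi, wi in zip(row, col)) for col in zip(*w)]
            for row in x]


def grouped_mm_reference(a, weights, indptr):
    """Multiply each expert's slice of rows by that expert's weight matrix.

    Assumed layout (hypothetical, for illustration only):
      a       : (total_tokens, k) rows, sorted by expert
      weights : one (k, n) matrix per expert
      indptr  : cumulative token counts, length num_experts + 1
    """
    out = []
    for e, w in enumerate(weights):
        rows = a[indptr[e]:indptr[e + 1]]  # empty if no tokens were routed
        out.extend(matmul(rows, w))
    return out


# Three experts receiving 2, 0, and 1 tokens respectively.
tokens_per_expert = [2, 0, 1]
m_indptr = build_indptr(tokens_per_expert)

a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
weights = [
    [[1.0, 0.0], [0.0, 1.0]],  # expert 0: identity
    [[2.0, 0.0], [0.0, 2.0]],  # expert 1: receives no tokens
    [[0.0, 1.0], [1.0, 0.0]],  # expert 2: swaps the two columns
]
out = grouped_mm_reference(a, weights, m_indptr)
```

In this sketch an expert with zero tokens produces an empty slice, so it contributes no output rows and the output keeps the same row order as the input.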

The functions in this module mirror the dense flashinfer.gemm.mm_* APIs and currently dispatch to the cuDNN MoE backend.

BF16 / FP16

grouped_mm_bf16(a, b, m_indptr[, out, ...])

Grouped matrix multiplication with BF16/FP16 data types (cuDNN MoE backend).

FP8

grouped_mm_fp8(a, b, m_indptr[, alpha, out, ...])

Grouped matrix multiplication with FP8 data types (cuDNN MoE backend).

MXFP8

grouped_mm_mxfp8(a, b, a_descale, b_descale, ...)

Grouped matrix multiplication with MXFP8 data types (cuDNN MoE backend).

FP4 (NVFP4 / MXFP4)

grouped_mm_fp4(a, b, a_descale, b_descale, ...)

Grouped matrix multiplication with FP4 data types (cuDNN MoE backend).