flashinfer.grouped_mm.grouped_mm_mxfp8
- flashinfer.grouped_mm.grouped_mm_mxfp8(a: Tensor, b: Tensor, a_descale: Tensor, b_descale: Tensor, m_indptr: Tensor, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, *, backend: str = 'cudnn', tactic: int = -1) → Tensor
Grouped matrix multiplication with MXFP8 data types (cuDNN MoE backend).
Performs a grouped GEMM across experts, as used in Mixture-of-Experts layers. Mirrors flashinfer.mm_mxfp8() but for expert-partitioned inputs:
\[\text{out}[\text{start}:\text{end}] = a[\text{start}:\text{end}] \times b[e]^T \quad \text{for each expert } e\]
where start = m_indptr[e] and end = m_indptr[e+1].
- Parameters:
  - a (torch.Tensor) – Token activations, shape (cum_m, k), e4m3 or e5m2.
  - b (torch.Tensor) – Expert weights, shape (batch_size, n, k), e4m3 or e5m2.
  - a_descale (torch.Tensor) – Block scale tensor for A. Can be:
    - 2D swizzled 128x4: shape (cum_m, k // 32), dtype uint8.
  - b_descale (torch.Tensor) – Block scale tensor for B. Can be:
    - 3D swizzled 128x4: shape (batch_size, n, k // 32), dtype uint8.
  - m_indptr (torch.Tensor) – Cumulative token counts, shape (batch_size + 1,), dtype int32.
  - out (Optional[torch.Tensor]) – Pre-allocated output tensor, shape (m_out, n).
  - out_dtype (torch.dtype) – Output data type: torch.bfloat16 (default), torch.float16, or torch.float32.
  - backend (str) – Backend selector. Currently only "cudnn" is supported.
  - tactic (int) – cuDNN execution-plan index. -1 (default) uses the heuristic-best plan; non-negative values select a specific plan.
- Returns:
Output tensor of shape (m_out, n).
- Return type:
torch.Tensor
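
m_indptr is an exclusive prefix sum of the per-expert token counts, so expert e owns rows m_indptr[e] up to m_indptr[e+1] of a, and m_indptr[-1] is the total token count. A minimal sketch of building it, assuming the rows of a are already sorted by expert and that routing has produced a per-expert count tensor. build_m_indptr is a hypothetical helper for illustration, not a FlashInfer API:

```python
import torch


def build_m_indptr(tokens_per_expert: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-expert token counts -> (batch_size + 1,) int32 offsets."""
    m_indptr = torch.zeros(tokens_per_expert.numel() + 1, dtype=torch.int32,
                           device=tokens_per_expert.device)
    # Inclusive cumulative sum written from index 1 onward gives the exclusive prefix sum.
    # dtype=torch.int32 makes the accumulation match the int32 dtype m_indptr requires.
    m_indptr[1:] = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32)
    return m_indptr


# Three experts receiving 2, 0 and 5 tokens; expert 1 gets an empty row range.
m_indptr = build_m_indptr(torch.tensor([2, 0, 5]))
print(m_indptr.tolist())  # [0, 2, 2, 7]
```

An expert with zero tokens simply gets start == end, so it contributes no output rows.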
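
When validating results, an unquantized reference of the per-expert formula out[start:end] = a[start:end] × b[e]^T can help. The sketch below computes it in float32 with plain PyTorch. grouped_mm_reference is a hypothetical helper, not a FlashInfer function, and it assumes the output has cum_m rows, since this page does not define m_out further:

```python
import torch


def grouped_mm_reference(a: torch.Tensor, b: torch.Tensor, m_indptr: torch.Tensor) -> torch.Tensor:
    """Hypothetical float32 reference: out[start:end] = a[start:end] @ b[e].T for each expert e.

    a: (cum_m, k), b: (batch_size, n, k), m_indptr: (batch_size + 1,) int32.
    Assumes the output has cum_m rows; rows not covered by m_indptr stay zero.
    """
    cum_m, k = a.shape
    batch_size, n, k_b = b.shape
    assert k == k_b, "a and b must share the reduction dimension k"
    assert m_indptr.numel() == batch_size + 1, "m_indptr needs batch_size + 1 entries"
    out = torch.zeros(cum_m, n, dtype=torch.float32, device=a.device)
    for e in range(batch_size):
        start, end = int(m_indptr[e]), int(m_indptr[e + 1])
        # Expert e's contiguous token slice times the transpose of expert e's (n, k) weight.
        out[start:end] = a[start:end].float() @ b[e].float().T
    return out


torch.manual_seed(0)
a = torch.randn(7, 64)                                    # cum_m = 7 tokens, k = 64
b = torch.randn(3, 16, 64)                                # batch_size = 3 experts, n = 16
m_indptr = torch.tensor([0, 2, 2, 7], dtype=torch.int32)  # expert 1 receives no tokens
out = grouped_mm_reference(a, b, m_indptr)
print(out.shape)  # torch.Size([7, 16])
```

To compare against the MXFP8 kernel, one would feed this reference the dequantized FP8 inputs (values times their block scales) and allow an FP8-level tolerance rather than exact equality.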