flashinfer.grouped_mm.grouped_mm_fp4
- flashinfer.grouped_mm.grouped_mm_fp4(a: Tensor, b: Tensor, a_descale: Tensor, b_descale: Tensor, m_indptr: Tensor, alpha: Tensor | None = None, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, block_size: int = 16, *, backend: str = 'cudnn', tactic: int = -1) → Tensor
Grouped matrix multiplication with FP4 data types (cuDNN MOE backend).
Performs a grouped GEMM across experts, as used in Mixture-of-Experts layers. Mirrors flashinfer.mm_fp4() but for expert-partitioned inputs:

\[\text{out}[\text{start}:\text{end}] = a[\text{start}:\text{end}] \times b[e]^T \quad \text{for each expert } e\]

where start, end = m_indptr[e], m_indptr[e+1].
- Parameters:
  - a (torch.Tensor) – Token activations, shape (cum_m, k), fp4 e2m1fn_x2 or uint8.
  - b (torch.Tensor) – Expert weights, shape (batch_size, n, k), fp4 e2m1fn_x2 or uint8.
  - a_descale (torch.Tensor) – Block scale tensor for A. Can be:
    - 2D swizzled 128x4: shape (cum_m, k // block_size), dtype float8_e4m3fn or uint8.
  - b_descale (torch.Tensor) – Block scale tensor for B. Can be:
    - 3D swizzled 128x4: shape (batch_size, n, k // block_size), dtype float8_e4m3fn or uint8.
  - m_indptr (torch.Tensor) – Cumulative per-expert token counts, shape (batch_size + 1,), int32. Expert e owns rows m_indptr[e] to m_indptr[e+1] of a and of the output.
  - alpha (Optional[torch.Tensor]) – Scaling factor for the output, shape (1,).
  - out (Optional[torch.Tensor]) – Pre-allocated output tensor of shape (m_out, n).
  - out_dtype (torch.dtype) – Output data type: torch.bfloat16 (default), torch.float16, or torch.float32.
  - block_size (int) – Block size used for the FP4 scale layout. 16 selects NVFP4 (with float8_e4m3fn scales) and 32 selects MXFP4 (with uint8 scales). Defaults to 16.
  - backend (str) – Backend selector. Currently only "cudnn" is supported.
  - tactic (int) – cuDNN execution-plan index. -1 (default) uses the heuristic-best plan; non-negative values select a specific plan.
- Returns:
  Output tensor of shape (m_out, n).
- Return type:
  torch.Tensor
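m_indptr uses a CSR-style offset layout: entry e is the first row of a (and of the output) belonging to expert e, and the last entry should equal the total token count cum_m. A minimal pure-Python sketch of building it from per-expert token counts follows; build_m_indptr and expert_row_ranges are hypothetical helpers for illustration, not flashinfer functions.

```python
from itertools import accumulate


def build_m_indptr(tokens_per_expert: list[int]) -> list[int]:
    """Hypothetical helper: per-expert token counts -> cumulative m_indptr.

    Entry e is the first row routed to expert e; the last entry is cum_m.
    """
    if any(c < 0 for c in tokens_per_expert):
        raise ValueError("token counts must be non-negative")
    return [0, *accumulate(tokens_per_expert)]


def expert_row_ranges(m_indptr: list[int]) -> list[tuple[int, int]]:
    """Hypothetical helper: the (start, end) row slice owned by each expert."""
    return [(m_indptr[e], m_indptr[e + 1]) for e in range(len(m_indptr) - 1)]


# Three experts receiving 2, 0 and 3 tokens: batch_size = 3, cum_m = 5.
m_indptr = build_m_indptr([2, 0, 3])
print(m_indptr)                     # [0, 2, 2, 5]
print(expert_row_ranges(m_indptr))  # [(0, 2), (2, 2), (2, 5)]
```

Before a real call, the list would still have to become an int32 tensor on the GPU, for example torch.tensor(m_indptr, dtype=torch.int32, device="cuda"). An expert that receives no tokens simply gets an empty row range.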
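The grouped GEMM formula can be mirrored by a plain PyTorch reference. The sketch below assumes a and b have already been dequantized to floating point (the real kernel consumes packed FP4 data together with a_descale and b_descale), and grouped_mm_reference is a hypothetical helper, not part of flashinfer. It loops over experts, multiplies each expert's row slice by that expert's transposed weight, and applies the optional alpha.

```python
from typing import Optional

import torch


def grouped_mm_reference(
    a: torch.Tensor,         # (cum_m, k), already dequantized to float
    b: torch.Tensor,         # (batch_size, n, k), already dequantized to float
    m_indptr: torch.Tensor,  # (batch_size + 1,), row offsets into a
    alpha: Optional[torch.Tensor] = None,  # (1,) output scale
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Hypothetical float reference: out[start:end] = a[start:end] @ b[e].T."""
    out = torch.zeros((a.shape[0], b.shape[1]), dtype=torch.float32, device=a.device)
    for e in range(b.shape[0]):
        start, end = int(m_indptr[e]), int(m_indptr[e + 1])
        # Rows routed to expert e are multiplied by that expert's transposed weight.
        out[start:end] = a[start:end].float() @ b[e].float().T
    if alpha is not None:
        out = out * alpha.float()
    return out.to(out_dtype)


# Three experts receiving 2, 0 and 3 tokens (cum_m = 5), with k = 8 and n = 4.
torch.manual_seed(0)
counts = torch.tensor([2, 0, 3])
m_indptr = torch.cat([torch.zeros(1, dtype=counts.dtype), counts.cumsum(0)]).to(torch.int32)
a = torch.randn(5, 8)
b = torch.randn(3, 4, 8)
out = grouped_mm_reference(a, b, m_indptr, alpha=torch.tensor([0.5]), out_dtype=torch.float32)
print(out.shape)  # torch.Size([5, 4])
```

Comparing such a reference with the FP4 kernel's output would need tolerances loose enough to absorb FP4 quantization error. The float32 accumulation is a choice of this sketch, not a statement about what cuDNN does internally.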
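Each scale covers block_size consecutive elements along k, so the last dimension of both scale tensors is k // block_size, and block_size also determines the scale dtype. The helper below only reproduces the logical shapes listed for a_descale and b_descale; it does not model the 128x4 swizzled memory layout, and its requirement that k be a multiple of block_size is an assumption of the sketch. expected_scale_shapes is a hypothetical name.

```python
# Scale dtype per block_size, following the block_size description:
# 16 -> NVFP4 with float8_e4m3fn scales, 32 -> MXFP4 with uint8 scales.
SCALE_DTYPE = {16: "float8_e4m3fn", 32: "uint8"}


def expected_scale_shapes(cum_m: int, batch_size: int, n: int, k: int,
                          block_size: int = 16) -> dict:
    """Hypothetical helper: logical shapes and dtype of a_descale / b_descale.

    Does not model the 128x4 swizzled storage layout.
    """
    if block_size not in SCALE_DTYPE:
        raise ValueError("block_size must be 16 (NVFP4) or 32 (MXFP4)")
    # Assumption of this sketch: k must split evenly into scale blocks.
    if k % block_size != 0:
        raise ValueError(f"k={k} is not a multiple of block_size={block_size}")
    k_blocks = k // block_size  # one scale per block_size elements along k
    return {
        "a_descale": (cum_m, k_blocks),
        "b_descale": (batch_size, n, k_blocks),
        "dtype": SCALE_DTYPE[block_size],
    }


print(expected_scale_shapes(cum_m=5, batch_size=3, n=256, k=512, block_size=16))
# {'a_descale': (5, 32), 'b_descale': (3, 256, 32), 'dtype': 'float8_e4m3fn'}
```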
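The FP4 element format is E2M1 (1 sign, 2 exponent and 1 mantissa bit), so every element is one of 0, ±0.5, ±1, ±1.5, ±2, ±3, ±4 or ±6, and e2m1fn_x2 / uint8 storage packs two elements per byte. The pure-Python sketch below decodes E2M1 nibbles and dequantizes one MXFP4 block (block_size = 32). It rests on two labeled assumptions: that the low nibble of each byte holds the first element, and that the uint8 MXFP4 scale is an OCP MX E8M0 exponent worth 2**(s - 127). The NVFP4 case (block_size = 16) works the same way except that the scale is float8_e4m3fn, which this sketch does not decode. All function names are hypothetical.

```python
def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 value: 1 sign, 2 exponent (bias 1), 1 mantissa bit."""
    sign = -1.0 if nibble & 0b1000 else 1.0
    exp = (nibble >> 1) & 0b11
    man = nibble & 0b1
    if exp == 0:  # subnormal: 0 or 0.5
        return sign * 0.5 * man
    return sign * 2.0 ** (exp - 1) * (1.0 + 0.5 * man)


def unpack_fp4_byte(byte: int) -> tuple[float, float]:
    """Split one packed byte into two FP4 values.

    ASSUMPTION: the low nibble holds the first element; verify against the
    packing convention of the quantizer that produced the tensor.
    """
    return decode_e2m1(byte & 0x0F), decode_e2m1(byte >> 4)


def dequant_mxfp4_block(packed: bytes, scale_e8m0: int) -> list[float]:
    """Dequantize one 32-element MXFP4 block stored in 16 packed bytes.

    ASSUMPTION: the uint8 scale is an E8M0 exponent, 2**(s - 127); the NaN
    encoding 0xFF is not handled.
    """
    if len(packed) != 16:
        raise ValueError("an MXFP4 block holds 32 elements in 16 bytes")
    scale = 2.0 ** (scale_e8m0 - 127)
    values: list[float] = []
    for byte in packed:
        values.extend(unpack_fp4_byte(byte))
    return [v * scale for v in values]


print([decode_e2m1(i) for i in range(8)])  # [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
print(dequant_mxfp4_block(bytes([0x72] * 16), scale_e8m0=128)[:4])  # [2.0, 12.0, 2.0, 12.0]
```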