flashinfer.grouped_mm.grouped_mm_bf16

flashinfer.grouped_mm.grouped_mm_bf16(a: Tensor, b: Tensor, m_indptr: Tensor, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, *, backend: str = 'cudnn', tactic: int = -1) → Tensor

Grouped matrix multiplication with BF16/FP16 data types (cuDNN MoE backend).

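Performs a grouped GEMM across experts, as used in Mixture-of-Experts (MoE) layers: each expert's contiguous block of token rows in a is multiplied by that expert's weight matrix. It mirrors flashinfer.mm_bf16(), but operates on inputs partitioned by expert.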

\[\text{out}[\text{start}:\text{end}] = a[\text{start}:\text{end}] \times b[e]^T \quad \text{for each expert } e\]

where start = m_indptr[e] and end = m_indptr[e+1], for e = 0, …, batch_size - 1.
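To make the indexing concrete, here is a framework-free reference sketch of the formula above in plain Python on nested lists. It is not FlashInfer's implementation: the name grouped_mm_reference is hypothetical, and the sketch assumes that m_indptr starts at 0, is non-decreasing, and ends at cum_m, so the expert segments exactly tile the rows of a.

```python
from typing import List

Matrix = List[List[float]]


def grouped_mm_reference(a: Matrix, b: List[Matrix], m_indptr: List[int]) -> Matrix:
    """Compute out[start:end] = a[start:end] @ b[e]^T for every expert e.

    a        -- cum_m rows of length k (token activations)
    b        -- batch_size matrices, each n rows of length k (expert weights)
    m_indptr -- batch_size + 1 cumulative row offsets into a
    """
    batch_size = len(b)
    if len(m_indptr) != batch_size + 1:
        raise ValueError("m_indptr must have batch_size + 1 entries")
    # Assumption of this sketch: the segments exactly tile a.
    if m_indptr[0] != 0 or m_indptr[-1] != len(a):
        raise ValueError("m_indptr must start at 0 and end at len(a)")
    if any(m_indptr[i] > m_indptr[i + 1] for i in range(batch_size)):
        raise ValueError("m_indptr must be non-decreasing")

    out: Matrix = []
    for e in range(batch_size):
        start, end = m_indptr[e], m_indptr[e + 1]
        w = b[e]  # shape (n, k): multiplying by w^T dots each row of a with each row of w
        for row in a[start:end]:
            out.append([sum(x * y for x, y in zip(row, w_row)) for w_row in w])
    return out


a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # cum_m = 3, k = 2
b = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],  # expert 0, shape (n=3, k=2)
     [[1.0, 1.0], [0.0, 2.0], [2.0, 0.0]]]  # expert 1
print(grouped_mm_reference(a, b, [0, 2, 3]))
# [[1.0, 3.0, 5.0], [2.0, 4.0, 6.0], [2.0, 2.0, 2.0]]
```

Each output row depends only on its own row of a and one expert's weights, so the experts are independent sub-problems of different sizes; a loop like this is only useful as a slow correctness oracle on small inputs.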

Parameters:
  • a (torch.Tensor) – Token activations, shape (cum_m, k), bf16 or fp16, where cum_m is the total number of tokens across all experts.

  • b (torch.Tensor) – Expert weights, shape (batch_size, n, k), one (n, k) matrix per expert.

  • m_indptr (torch.Tensor) – Cumulative token counts, shape (batch_size + 1,), int32. Rows m_indptr[e] to m_indptr[e+1] of a belong to expert e.

  • out (Optional[torch.Tensor]) – Optional pre-allocated output tensor of shape (m_out, n).

  • out_dtype (torch.dtype) – Output data type: one of torch.bfloat16 (default), torch.float16, or torch.float32.

  • backend (str) – Backend selector. Currently only "cudnn" is supported.

  • tactic (int) – cuDNN execution-plan index. -1 (default) uses the heuristic-best plan; non-negative values select a specific plan.
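The Examples section assigns the same number of tokens to every expert, but in an MoE layer the router usually produces uneven counts. The following is a minimal sketch of how tokens could be grouped by expert and m_indptr derived as an exclusive prefix sum, assuming top-1 routing (each token goes to exactly one expert). The helper names build_m_indptr and sort_tokens_by_expert are hypothetical, and the documentation above does not say whether experts with zero tokens are accepted by the cuDNN backend, so treat that case as an assumption to verify.

```python
from itertools import accumulate
from typing import List, Tuple


def build_m_indptr(tokens_per_expert: List[int]) -> List[int]:
    """Exclusive prefix sum: entry e is the first row of expert e; the last entry is cum_m."""
    if any(c < 0 for c in tokens_per_expert):
        raise ValueError("token counts must be non-negative")
    return [0, *accumulate(tokens_per_expert)]


def sort_tokens_by_expert(expert_ids: List[int], num_experts: int) -> Tuple[List[int], List[int]]:
    """Return (permutation, m_indptr) for top-1 routing.

    Gathering the activation rows in `permutation` order makes each expert's
    tokens contiguous, which is the layout m_indptr describes.
    """
    counts = [0] * num_experts
    for e in expert_ids:
        if not 0 <= e < num_experts:
            raise ValueError(f"expert id {e} out of range")
        counts[e] += 1
    # sorted() is stable, so tokens keep their original order within an expert.
    permutation = sorted(range(len(expert_ids)), key=lambda t: expert_ids[t])
    return permutation, build_m_indptr(counts)


# Five tokens routed to four experts; expert 3 receives no tokens
# (whether the real kernel accepts an empty segment is an assumption).
perm, m_indptr = sort_tokens_by_expert([2, 0, 2, 1, 0], num_experts=4)
print(perm)      # [1, 4, 3, 0, 2]
print(m_indptr)  # [0, 2, 3, 5, 5]
```

In PyTorch the same steps correspond to torch.bincount for the per-expert counts, torch.cumsum with a leading zero for the offsets (cast to torch.int32, as m_indptr requires), and a stable sort of the expert ids for the permutation.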

Returns:

Output tensor of shape (m_out, n).

Return type:

torch.Tensor

Examples

>>> import torch, flashinfer
>>> E, tpe, k, n = 8, 128, 4096, 2048
>>> a = torch.randn(E * tpe, k, dtype=torch.bfloat16, device="cuda")
>>> b = torch.randn(E, n, k, dtype=torch.bfloat16, device="cuda")
>>> m_indptr = (torch.arange(E + 1, device="cuda") * tpe).to(torch.int32)
>>> out = flashinfer.grouped_mm.grouped_mm_bf16(a, b, m_indptr)