flashinfer.grouped_mm.grouped_mm_bf16
- flashinfer.grouped_mm.grouped_mm_bf16(a: Tensor, b: Tensor, m_indptr: Tensor, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, *, backend: str = 'cudnn', tactic: int = -1) → Tensor
Grouped matrix multiplication with BF16/FP16 data types (cuDNN MOE backend).

Performs a grouped GEMM across experts, as used in Mixture-of-Experts layers. Mirrors flashinfer.mm_bf16() but for expert-partitioned inputs.

\[\text{out}[\text{start}:\text{end}] = a[\text{start}:\text{end}] \times b[e]^T \quad \text{for each expert } e\]

where start, end = m_indptr[e], m_indptr[e+1].

- Parameters:
  a (torch.Tensor) – Token activations, shape (cum_m, k), bf16 or fp16.
  b (torch.Tensor) – Expert weights, shape (batch_size, n, k): one (n, k) weight matrix per expert.
  m_indptr (torch.Tensor) – Cumulative token counts, shape (batch_size + 1,), int32. Rows m_indptr[e]:m_indptr[e+1] of a hold the tokens for expert e.
  out (Optional[torch.Tensor]) – Pre-allocated output tensor of shape (m_out, n).
  out_dtype (torch.dtype) – Output data type: torch.bfloat16 (default), torch.float16, or torch.float32.
  backend (str) – Backend selector. Currently only "cudnn" is supported.
  tactic (int) – cuDNN execution-plan index. -1 (default) uses the heuristic-best plan; non-negative values select a specific plan.
- Returns:
  Output tensor of shape (m_out, n).
- Return type:
  torch.Tensor
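The example below gives every expert the same number of tokens. With real routing the per-expert counts differ, and a must already be grouped so that each expert's tokens are contiguous. A minimal sketch of preparing a and m_indptr, assuming top-1 routing given as a 1-D tensor of expert ids (one per token); partition_by_expert is a hypothetical helper, not a flashinfer function:

```python
import torch


def partition_by_expert(x: torch.Tensor, expert_ids: torch.Tensor, num_experts: int):
    """Hypothetical helper (not part of flashinfer).

    Reorders the rows of x so tokens routed to the same expert are contiguous.
    Returns (a, m_indptr, order) where a == x[order] and rows
    m_indptr[e]:m_indptr[e + 1] of a belong to expert e.
    """
    # Stable sort keeps the original token order within each expert.
    _, order = torch.sort(expert_ids, stable=True)
    # Number of tokens routed to each expert (experts with no tokens count as 0).
    counts = torch.bincount(expert_ids, minlength=num_experts)
    # Exclusive prefix sum: m_indptr[0] = 0, m_indptr[e + 1] = sum(counts[: e + 1]).
    m_indptr = torch.zeros(num_experts + 1, dtype=torch.int32, device=x.device)
    m_indptr[1:] = torch.cumsum(counts, dim=0)
    return x[order], m_indptr, order


# Six tokens with hidden size 1, routed to three experts.
x = torch.arange(6, dtype=torch.float32).unsqueeze(1)
expert_ids = torch.tensor([2, 0, 2, 1, 0, 2])
a, m_indptr, order = partition_by_expert(x, expert_ids, num_experts=3)
print(m_indptr.tolist())  # [0, 2, 3, 6]
```

After the grouped GEMM, rows can be put back in the original token order with result[order] = out, since row i of out corresponds to token order[i].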
Examples
>>> import torch, flashinfer
>>> E, tpe, k, n = 8, 128, 4096, 2048  # E experts, tpe tokens per expert
>>> a = torch.randn(E * tpe, k, dtype=torch.bfloat16, device="cuda")
>>> b = torch.randn(E, n, k, dtype=torch.bfloat16, device="cuda")
>>> m_indptr = (torch.arange(E + 1, device="cuda") * tpe).to(torch.int32)  # [0, 128, 256, ...]
>>> out = flashinfer.grouped_mm.grouped_mm_bf16(a, b, m_indptr)
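For checking results, the formula above can also be written as a plain PyTorch loop over experts. This is a hypothetical reference sketch (grouped_mm_reference is not part of flashinfer). It computes in float32 and then casts to out_dtype, which is not necessarily how the cuDNN kernel accumulates, and it assumes the output has cum_m rows:

```python
import torch


def grouped_mm_reference(a: torch.Tensor, b: torch.Tensor, m_indptr: torch.Tensor,
                         out_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Hypothetical loop-over-experts reference: out[start:end] = a[start:end] @ b[e]^T.

    a: (cum_m, k), b: (E, n, k), m_indptr: (E + 1,) row offsets into a.
    """
    num_experts, n, _ = b.shape
    out = torch.zeros(a.shape[0], n, dtype=out_dtype, device=a.device)
    offsets = m_indptr.tolist()
    for e in range(num_experts):
        start, end = offsets[e], offsets[e + 1]
        # (end - start, k) @ (k, n) -> (end - start, n), computed in float32.
        out[start:end] = (a[start:end].float() @ b[e].float().T).to(out_dtype)
    return out


# Small CPU-sized problem with equal tokens per expert, as in the example above.
torch.manual_seed(0)
E, tpe, k, n = 4, 3, 8, 5
a = torch.randn(E * tpe, k)
b = torch.randn(E, n, k)
m_indptr = (torch.arange(E + 1) * tpe).to(torch.int32)
ref = grouped_mm_reference(a, b, m_indptr, out_dtype=torch.float32)
```

When every expert receives the same number of tokens, the loop is equivalent to a batched matmul, torch.bmm(a.view(E, tpe, k), b.transpose(1, 2)); the loop form also covers uneven m_indptr offsets.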