flashinfer.gemm.mm_M1_16_K7168_N128
- flashinfer.gemm.mm_M1_16_K7168_N128(mat_a: Tensor, mat_b: Tensor, out: Tensor, launch_with_pdl: bool = True) → None
Optimized GEMM for the router operation in Mistral Large 3.
Performs a matrix multiplication tailored to the expert routing GEMM in Mistral Large 3’s Mixture-of-Experts (MoE) architecture. Computes out = mat_a @ mat_b, where mat_a is a small batch of token embeddings (1-16 rows) and mat_b is the expert routing weight matrix. Specialized for the dimensions used in Mistral Large 3 MoE (K = 7168, N = 128).
- Parameters:
- mat_a (torch.Tensor) – Input token embeddings of shape (M, K), where M is the number of tokens (1-16) and K is the hidden dimension (7168). Must be bfloat16, row-major (contiguous).
- mat_b (torch.Tensor) – Expert routing weights of shape (K, N), where N is the number of experts (128). Must be bfloat16, column-major (transposed layout).
- out (torch.Tensor) – Pre-allocated output tensor of shape (M, N) that receives the routing scores. Must be bfloat16, row-major (contiguous). Mutated in place.
- launch_with_pdl (bool) – Whether to launch the kernel using Programmatic Dependent Launch. Defaults to True.
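The layout requirements above can be written out as concrete shapes and strides. The following is a minimal sketch in plain Python, not FlashInfer's own validation code: the helper names expected_layout and check_layout are hypothetical, strides are counted in elements (as torch.Tensor.stride() reports them), and "column-major (transposed layout)" is assumed to mean strides of (1, K) for mat_b.

```python
K, N = 7168, 128   # hidden dimension and expert count from the parameter list
MAX_TOKENS = 16    # documented upper bound on M


def expected_layout(num_tokens):
    """Return {name: (shape, strides)} with strides in elements.

    Hypothetical helper for illustration; the stride values are an
    assumption derived from the documented row-/column-major layouts.
    """
    if not 1 <= num_tokens <= MAX_TOKENS:
        raise ValueError(f"M must be in [1, {MAX_TOKENS}], got {num_tokens}")
    return {
        "mat_a": ((num_tokens, K), (K, 1)),  # row-major (contiguous)
        "mat_b": ((K, N), (1, K)),           # column-major (transposed)
        "out": ((num_tokens, N), (N, 1)),    # row-major (contiguous)
    }


def check_layout(name, shape, strides, num_tokens):
    """Raise ValueError if a tensor's shape or strides differ from the sketch."""
    want_shape, want_strides = expected_layout(num_tokens)[name]
    if tuple(shape) != want_shape or tuple(strides) != want_strides:
        raise ValueError(
            f"{name}: expected shape {want_shape} with strides {want_strides}, "
            f"got shape {tuple(shape)} with strides {tuple(strides)}"
        )
```

With PyTorch tensors, such a check could be called as check_layout("mat_b", mat_b.shape, mat_b.stride(), M). The dtype requirement (bfloat16) is not covered by this sketch.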
Notes
Requires a GPU with the Blackwell SM100/SM103 architecture. Because the kernel is specialized for this one problem size, it is significantly faster than general-purpose GEMM implementations for the router op.
Raises ValueError if tensor dimensions, strides, or dtypes do not match the expected Mistral Large 3 configuration.
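A call that satisfies the documented constraints might look like the sketch below. It assumes that SM100 and SM103 correspond to CUDA compute capabilities 10.0 and 10.3, and that a column-major mat_b can be produced by allocating an (N, K) tensor and transposing it. The helpers is_supported_capability and run_router_gemm are hypothetical names used only for illustration, and the inputs are random placeholders rather than real router weights.

```python
def is_supported_capability(major, minor):
    """True for compute capability 10.0 or 10.3 (assumed mapping for SM100/SM103)."""
    return (major, minor) in {(10, 0), (10, 3)}


def run_router_gemm(num_tokens=8):
    """Run the router GEMM on random inputs and return the (M, N) output.

    Needs torch, flashinfer, and a supported GPU, so the imports are kept
    inside the function to let the capability helper be used on its own.
    """
    import torch
    from flashinfer.gemm import mm_M1_16_K7168_N128

    if not torch.cuda.is_available():
        raise RuntimeError("a CUDA device is required")
    if not is_supported_capability(*torch.cuda.get_device_capability()):
        raise RuntimeError("this kernel requires Blackwell SM100/SM103")

    hidden, experts = 7168, 128  # K and N from the documentation

    # Row-major token embeddings, shape (M, K).
    mat_a = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device="cuda")
    # Allocate as (N, K) and transpose: shape (K, N) with strides (1, K),
    # which is the assumed meaning of "column-major (transposed layout)".
    mat_b = torch.randn(experts, hidden, dtype=torch.bfloat16, device="cuda").t()
    # Pre-allocated row-major output, shape (M, N); written in place.
    out = torch.empty(num_tokens, experts, dtype=torch.bfloat16, device="cuda")

    mm_M1_16_K7168_N128(mat_a, mat_b, out, launch_with_pdl=True)
    return out
```

The function returns None and writes into out, so the result is read from the pre-allocated tensor rather than from a return value. Passing launch_with_pdl=False launches the kernel without Programmatic Dependent Launch.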