flashinfer.gemm.mm_M1_16_K7168_N128

flashinfer.gemm.mm_M1_16_K7168_N128(mat_a: Tensor, mat_b: Tensor, out: Tensor, launch_with_pdl: bool = True) → None

Optimized GEMM for the router operation in Mistral Large 3.

Performs a matrix multiplication tailored to the expert routing (gating) GEMM in Mistral Large 3’s Mixture-of-Experts (MoE) architecture. It computes out = mat_a @ mat_b, where mat_a is a small batch of token embeddings (1 to 16 rows) and mat_b is the expert routing weight matrix. The kernel is specialized for the fixed dimensions of the Mistral Large 3 router: K = 7168 (hidden dimension) and N = 128 (number of experts).
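For checking results, the computation can be written as a plain PyTorch reference. This is a minimal sketch, not FlashInfer code: the helper name `router_gemm_reference` is hypothetical, and accumulating in float32 before rounding to bfloat16 is an assumption about the kernel's numerics rather than a documented guarantee.

```python
import torch

def router_gemm_reference(mat_a: torch.Tensor, mat_b: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference for out = mat_a @ mat_b with (M, 7168) @ (7168, 128).

    Assumption: accumulate in float32, then round the result to bfloat16.
    """
    return (mat_a.float() @ mat_b.float()).to(torch.bfloat16)

M, K, N = 4, 7168, 128
mat_a = torch.randn(M, K, dtype=torch.bfloat16)     # token embeddings, row-major
weight = torch.randn(N, K, dtype=torch.bfloat16)    # router weight stored as (experts, hidden)
mat_b = weight.t()                                  # (K, N) view with column-major strides
out = router_gemm_reference(mat_a, mat_b)           # (M, N) routing scores
print(out.shape, out.dtype)
```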

Parameters:
  • mat_a (torch.Tensor) – Input token embeddings of shape (M, K), where M is the number of tokens (1 to 16) and K is the hidden dimension (7168). Must be bfloat16 and row-major (contiguous).

  • mat_b (torch.Tensor) – Expert routing weights of shape (K, N), where N is the number of experts (128). Must be bfloat16 and column-major, i.e. with strides (1, K); this is the layout of a transposed view of a contiguous (N, K) weight.

  • out (torch.Tensor) – Pre-allocated output tensor of shape (M, N) that receives the routing scores. Must be bfloat16 and row-major (contiguous). Mutated in place.

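  • launch_with_pdl (bool) – Whether to launch the kernel with Programmatic Dependent Launch (PDL), which allows it to begin launching before the preceding kernel in the same stream has fully finished. Defaults to True.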
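The layout rules above can be restated as a pre-flight check. This is a sketch under stated assumptions: `check_router_gemm_inputs`, `K_HIDDEN`, and `N_EXPERTS` are hypothetical names, and the function only mirrors the documented requirements; it is not FlashInfer's own validation logic. It also shows the usual way to obtain a column-major mat_b: keep the weight as a contiguous (N, K) tensor and pass its transpose.

```python
import torch

K_HIDDEN = 7168   # hidden dimension the kernel is specialized for
N_EXPERTS = 128   # number of experts the kernel is specialized for

def check_router_gemm_inputs(mat_a, mat_b, out):
    """Hypothetical check mirroring the documented parameter requirements."""
    for name, t in (("mat_a", mat_a), ("mat_b", mat_b), ("out", out)):
        if t.dtype != torch.bfloat16:
            raise ValueError(f"{name} must be bfloat16, got {t.dtype}")
        if t.dim() != 2:
            raise ValueError(f"{name} must be 2-D, got {t.dim()}-D")
    M, K = mat_a.shape
    if not (1 <= M <= 16) or K != K_HIDDEN or not mat_a.is_contiguous():
        raise ValueError(f"mat_a must be a contiguous (M, {K_HIDDEN}) tensor with 1 <= M <= 16")
    # Column-major (K, N): element (k, n) sits at offset k + n * K, so strides are (1, K).
    if tuple(mat_b.shape) != (K_HIDDEN, N_EXPERTS) or mat_b.stride() != (1, K_HIDDEN):
        raise ValueError(f"mat_b must be column-major ({K_HIDDEN}, {N_EXPERTS}), strides (1, {K_HIDDEN})")
    if tuple(out.shape) != (M, N_EXPERTS) or not out.is_contiguous():
        raise ValueError(f"out must be a contiguous ({M}, {N_EXPERTS}) tensor")

M = 8
mat_a = torch.randn(M, K_HIDDEN, dtype=torch.bfloat16)
# Contiguous (N, K) weight; .t() gives shape (K, N) with strides (1, K).
router_weight = torch.randn(N_EXPERTS, K_HIDDEN, dtype=torch.bfloat16)
mat_b = router_weight.t()
out = torch.empty(M, N_EXPERTS, dtype=torch.bfloat16)
check_router_gemm_inputs(mat_a, mat_b, out)  # passes

# A row-major (K, N) copy has the right shape but the wrong strides.
try:
    check_router_gemm_inputs(mat_a, mat_b.contiguous(), out)
except ValueError as err:
    print("rejected:", err)
```

Passing the transposed view avoids copying the weight; calling `.contiguous()` on it would produce a row-major tensor, which the documented requirements rule out.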

Notes

Requires a Blackwell GPU (SM100 or SM103). Because the kernel is specialized for this single problem size, it is significantly faster than general-purpose GEMM implementations for the router op. Raises ValueError if tensor dimensions, strides, or dtypes do not match the expected Mistral Large 3 configuration.
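Since the kernel only runs on SM100/SM103 and only for 1 to 16 tokens, callers may want to dispatch to it conditionally. The sketch below is one hedged way to do that: `router_logits` is a hypothetical helper, it assumes SM100 and SM103 report compute capabilities (10, 0) and (10, 3), and it falls back to a generic float32-accumulating matmul whenever the specialized path does not apply (including when FlashInfer is not installed).

```python
import torch

def router_logits(hidden: torch.Tensor, router_weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical dispatcher for the Mistral Large 3 router GEMM.

    hidden:        (M, 7168) bfloat16 token embeddings
    router_weight: (128, 7168) bfloat16 weight stored as (experts, hidden)
    """
    mat_b = router_weight.contiguous().t()  # column-major (7168, 128) view, no copy
    out = torch.empty(hidden.shape[0], router_weight.shape[0],
                      dtype=torch.bfloat16, device=hidden.device)
    use_kernel = (
        hidden.is_cuda
        and 1 <= hidden.shape[0] <= 16
        and torch.cuda.get_device_capability(hidden.device) in ((10, 0), (10, 3))
    )
    if use_kernel:
        try:
            from flashinfer.gemm import mm_M1_16_K7168_N128
        except ImportError:
            pass  # FlashInfer unavailable: use the fallback below
        else:
            mm_M1_16_K7168_N128(hidden.contiguous(), mat_b, out)
            return out
    # Fallback (assumption): float32 accumulation, rounded to bfloat16.
    out.copy_((hidden.float() @ mat_b.float()).to(torch.bfloat16))
    return out

hidden = torch.randn(4, 7168, dtype=torch.bfloat16)
weight = torch.randn(128, 7168, dtype=torch.bfloat16)
logits = router_logits(hidden, weight)  # on CPU this takes the fallback path
```

Checking the token count before dispatch keeps batches larger than 16 on the generic path instead of triggering the kernel's ValueError.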