flashinfer.gemm.mm_M1_16_K6144_N256
- flashinfer.gemm.mm_M1_16_K6144_N256(mat_a: Tensor, mat_b: Tensor, out: Tensor, launch_with_pdl: bool = True) → None
Optimized GEMM for the router operation in GLM-MoE-DSA.
Performs a highly optimized matrix multiplication specifically tailored for the expert routing GEMM in GLM-MoE-DSA’s Mixture-of-Experts (MoE) architecture. Computes out = mat_a @ mat_b, where mat_a is a small batch of token embeddings (1-16 rows) and mat_b is the expert routing weight matrix. Specialized for the dimensions used in GLM-MoE-DSA (K = 6144, N = 256).
- Parameters:
  - mat_a (torch.Tensor) – Input token embeddings of shape (M, K), where M is the number of tokens (1-16) and K is the hidden dimension (6144). Must be bfloat16, row-major (contiguous).
  - mat_b (torch.Tensor) – Expert routing weights of shape (K, N), where N is the number of experts (256). Must be bfloat16, column-major (transposed layout).
  - out (torch.Tensor) – Pre-allocated output tensor of shape (M, N) containing the routing scores. Must be float32, row-major (contiguous). Mutated in place.
  - launch_with_pdl (bool) – Whether to launch the kernel using Programmatic Dependent Launch. Defaults to True.
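The layout requirements above can be sketched as follows. This is an illustrative example, not part of the FlashInfer documentation: documented_layout and run_router_gemm are hypothetical helper names, and the stride tuples are what "row-major" and "column-major" normally mean for these shapes (strides in elements), which is assumed to be what the kernel checks. The example also assumes the routing weights are stored as an (N, K) tensor, so that a transposed view gives the column-major (K, N) operand.

```python
K, N = 6144, 256      # fixed by the kernel name and the description above
MAX_TOKENS = 16       # documented upper bound for M


def documented_layout(num_tokens):
    """Return {name: (shape, strides in elements, dtype name)} for a given M."""
    if not 1 <= num_tokens <= MAX_TOKENS:
        raise ValueError(f"M must be in [1, {MAX_TOKENS}], got {num_tokens}")
    m = num_tokens
    return {
        "mat_a": ((m, K), (K, 1), "bfloat16"),  # row-major
        "mat_b": ((K, N), (1, K), "bfloat16"),  # column-major
        "out": ((m, N), (N, 1), "float32"),     # row-major
    }


def run_router_gemm(num_tokens, device="cuda"):
    """Allocate inputs in the documented layout and call the kernel.

    Needs PyTorch, FlashInfer and an SM100/SM103 GPU, so the imports are
    local and the layout helper above stays usable without them.
    """
    import torch
    from flashinfer.gemm import mm_M1_16_K6144_N256

    mat_a = torch.randn(num_tokens, K, device=device, dtype=torch.bfloat16)
    # Assumption: weights are stored as (N, K); the transposed view has
    # shape (K, N) and strides (1, K), i.e. a column-major operand.
    weight = torch.randn(N, K, device=device, dtype=torch.bfloat16)
    mat_b = weight.t()
    out = torch.empty(num_tokens, N, device=device, dtype=torch.float32)

    # Sanity-check the tensors against the layout table before launching.
    layout = documented_layout(num_tokens)
    for name, t in (("mat_a", mat_a), ("mat_b", mat_b), ("out", out)):
        shape, strides, _ = layout[name]
        assert tuple(t.shape) == shape and t.stride() == strides, name

    mm_M1_16_K6144_N256(mat_a, mat_b, out)  # launch_with_pdl defaults to True
    return out  # written in place by the kernel
```

Calling run_router_gemm(4) on a supported GPU should return a (4, 256) float32 tensor of routing scores; a result from mat_a.float() @ mat_b.float() can serve as a rough reference when checking correctness.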
Notes
Requires Blackwell SM100/SM103 architecture. The specialized problem-size optimization makes this significantly faster than general-purpose GEMM implementations for the router op.
Raises ValueError if tensor dimensions, strides, or dtypes do not match the expected GLM-MoE-DSA configuration.
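When a call fails with ValueError, a small pre-check can help pinpoint which argument is off. The sketch below is a hypothetical helper (find_violations is not a FlashInfer function) that mirrors the documented constraints on dimensions, strides, and dtypes. The exact checks performed inside the library are not specified here, so treat this as an approximation of them. It accepts any object exposing .shape, .stride() and .dtype, such as a torch.Tensor, and compares dtypes by their string form (for example "torch.bfloat16").

```python
K, N, MAX_TOKENS = 6144, 256, 16  # values taken from the description above


def find_violations(mat_a, mat_b, out):
    """List how the arguments deviate from the documented configuration."""
    problems = []
    m = mat_a.shape[0] if len(mat_a.shape) == 2 else None
    if m is None or not 1 <= m <= MAX_TOKENS:
        problems.append(f"mat_a: M must be in [1, {MAX_TOKENS}], shape is {tuple(mat_a.shape)}")

    # (name, tensor, expected shape, expected strides in elements, expected dtype)
    expected = (
        ("mat_a", mat_a, (m, K), (K, 1), "torch.bfloat16"),  # row-major
        ("mat_b", mat_b, (K, N), (1, K), "torch.bfloat16"),  # column-major
        ("out", out, (m, N), (N, 1), "torch.float32"),       # row-major
    )
    for name, t, shape, strides, dtype in expected:
        if tuple(t.shape) != shape:
            problems.append(f"{name}: shape {tuple(t.shape)}, expected {shape}")
        elif tuple(t.stride()) != strides:
            problems.append(f"{name}: strides {tuple(t.stride())}, expected {strides}")
        if str(t.dtype) != dtype:
            problems.append(f"{name}: dtype {t.dtype}, expected {dtype}")
    return problems
```

An empty list means the arguments match the documented layout. A common failure this would flag is passing a contiguous (K, N) weight tensor, whose strides are (256, 1) rather than the column-major (1, 6144).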