flashinfer.gemm.mm_M1_16_K6144_N256

flashinfer.gemm.mm_M1_16_K6144_N256(mat_a: Tensor, mat_b: Tensor, out: Tensor, launch_with_pdl: bool = True) → None

Optimized GEMM for the router operation in GLM-MoE-DSA.

Performs a matrix multiplication specialized for the expert-routing GEMM in GLM-MoE-DSA’s Mixture-of-Experts (MoE) architecture. Computes out = mat_a @ mat_b, where mat_a is a small batch of token embeddings (1 to 16 rows) and mat_b is the expert routing weight matrix. The kernel is specialized for the dimensions used in GLM-MoE-DSA (K = 6144, N = 256).
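For checking results, the same product can be reproduced with plain PyTorch. The sketch below is only a reference computation, not the kernel's implementation: the helper name `router_gemm_reference` is hypothetical, and upcasting to float32 before the multiply is an assumption about how the kernel's accumulation should be matched.

```python
import torch

K, N = 6144, 256  # dimensions stated in the description above


def router_gemm_reference(mat_a: torch.Tensor, mat_b: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: out = mat_a @ mat_b, computed in float32."""
    # Upcast the bfloat16 inputs so the product has the same dtype as `out`.
    return mat_a.to(torch.float32) @ mat_b.to(torch.float32)


torch.manual_seed(0)
mat_a = torch.randn(16, K).to(torch.bfloat16)  # up to 16 token rows
mat_b = torch.randn(K, N).to(torch.bfloat16)   # one column per expert
ref = router_gemm_reference(mat_a, mat_b)
print(ref.shape, ref.dtype)
```

When comparing against the kernel output, a tolerance-based check such as `torch.testing.assert_close` is likely a better fit than exact equality, since accumulation order may differ.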

Parameters:
  • mat_a (torch.Tensor) – Input token embeddings of shape (M, K) where M is the number of tokens (1-16) and K is the hidden dimension (6144). Must be bfloat16, row-major (contiguous).

  • mat_b (torch.Tensor) – Expert routing weights of shape (K, N) where N is the number of experts (256). Must be bfloat16, column-major (transposed layout).

  • out (torch.Tensor) – Pre-allocated output tensor of shape (M, N) containing the routing scores. Must be float32, row-major (contiguous). Mutated in place.

  • launch_with_pdl (bool) – Whether to launch the kernel using Programmatic Dependent Launch (PDL). Defaults to True.
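The layout requirements above can be met as in the following sketch. It assumes that "column-major (transposed layout)" means a (K, N) view with strides (1, K), which is what transposing a contiguous (N, K) tensor produces. The helper names `make_router_inputs` and `run_router_gemm` are hypothetical. The allocation part runs on CPU; the kernel call itself is wrapped in a function that is not invoked here, because it needs a supported GPU.

```python
import torch

K, N = 6144, 256


def make_router_inputs(m: int = 4, device: str = "cpu"):
    """Allocate tensors with the dtypes and layouts listed in Parameters."""
    # mat_a: (M, K), bfloat16, row-major -> strides (K, 1)
    mat_a = torch.randn(m, K, device=device).to(torch.bfloat16)
    # mat_b: (K, N), bfloat16, column-major. Transposing a contiguous
    # (N, K) tensor gives a (K, N) view with strides (1, K) (assumed layout).
    mat_b = torch.randn(N, K, device=device).to(torch.bfloat16).t()
    # out: (M, N), float32, row-major; the kernel writes into it in place.
    out = torch.empty(m, N, dtype=torch.float32, device=device)
    return mat_a, mat_b, out


def run_router_gemm(mat_a, mat_b, out, use_pdl: bool = True):
    # Imported lazily so the allocation helper works without flashinfer.
    from flashinfer.gemm import mm_M1_16_K6144_N256

    mm_M1_16_K6144_N256(mat_a, mat_b, out, launch_with_pdl=use_pdl)
    return out  # same tensor object, now holding the routing scores


mat_a, mat_b, out = make_router_inputs()
print(mat_a.stride(), mat_b.stride(), out.stride())
```

On a supported GPU, the inputs would be created with `device="cuda"` and passed to `run_router_gemm`.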

Notes

Requires a Blackwell GPU (SM100 or SM103). Because the kernel is specialized for this single problem size, it is significantly faster than general-purpose GEMM implementations for the router op. Raises ValueError if tensor dimensions, strides, or dtypes do not match the expected GLM-MoE-DSA configuration.
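To fail early with a readable message, a caller might mirror these constraints in a pre-flight check. The sketch below is a hypothetical helper, not part of FlashInfer, and its error messages are illustrative. It works on plain shape, stride, and dtype descriptions (the `TensorSpec` class is also hypothetical) so that it runs without a GPU, and it assumes column-major for mat_b means strides (1, K).

```python
from dataclasses import dataclass
from typing import Tuple

K, N, MAX_M = 6144, 256, 16


@dataclass(frozen=True)
class TensorSpec:
    """Hypothetical stand-in for a tensor's shape, strides, and dtype name."""
    shape: Tuple[int, ...]
    stride: Tuple[int, ...]
    dtype: str


def check_router_gemm_args(mat_a: TensorSpec, mat_b: TensorSpec, out: TensorSpec) -> None:
    """Raise ValueError if the arguments do not match the documented layout."""
    if len(mat_a.shape) != 2:
        raise ValueError(f"mat_a must be 2-D, got shape {mat_a.shape}")
    m = mat_a.shape[0]
    if not 1 <= m <= MAX_M:
        raise ValueError(f"M must be in [1, {MAX_M}], got {m}")

    # name -> (spec, expected shape, expected strides, expected dtype)
    expected = {
        "mat_a": (mat_a, (m, K), (K, 1), "bfloat16"),  # row-major
        "mat_b": (mat_b, (K, N), (1, K), "bfloat16"),  # column-major (assumed)
        "out": (out, (m, N), (N, 1), "float32"),       # row-major
    }
    for name, (spec, shape, stride, dtype) in expected.items():
        if spec.shape != shape:
            raise ValueError(f"{name}: expected shape {shape}, got {spec.shape}")
        if spec.stride != stride:
            raise ValueError(f"{name}: expected strides {stride}, got {spec.stride}")
        if spec.dtype != dtype:
            raise ValueError(f"{name}: expected dtype {dtype}, got {spec.dtype}")
```

With PyTorch tensors, a spec could be built from `tuple(t.shape)`, `t.stride()`, and the dtype name, for example `str(t.dtype).replace("torch.", "")`.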