flashinfer.fused_moe.fused_topk_deepseek

flashinfer.fused_moe.fused_topk_deepseek(scores: Tensor, bias: Tensor, n_group: int, topk_group: int, topk: int, routed_scaling_factor: float, topk_values: Tensor, topk_indices: Tensor, launch_with_pdl: bool = True, routing_replay_out: Tensor | None = None) → None

Fused expert routing with top-k selection for DeepSeek-V3.

Performs the expert-routing step of DeepSeek-V3’s Mixture-of-Experts architecture, which uses grouped expert routing and no auxiliary load-balancing loss. Score computation, expert selection, and normalization are fused into a single kernel:

  1. Compute biased scores sigmoid(scores) + bias.

  2. Group experts and compute a score for each group: the sum of the two largest biased scores in that group.

  3. Select the top topk_group groups by group score.

  4. From the selected groups, pick the top topk experts by biased score.

  5. Normalize the weights of the selected experts using their unbiased sigmoid scores: sigmoid_scores / sum(sigmoid_scores) * routed_scaling_factor, where the sum runs over the topk selected experts.
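The five steps can be made concrete with a dependency-free, pure-Python sketch for a single token. This is an illustrative reference, not the fused kernel. `route_one_token` and `sigmoid` are hypothetical helper names. The sketch also makes three assumptions: groups are contiguous blocks of num_experts / n_group experts, arithmetic uses Python floats (double precision) rather than the kernel's float32, and ties go to the lower expert index, which may not match the kernel's tie-breaking.

```python
import math
from typing import List, Tuple


def sigmoid(x: float) -> float:
    # Numerically stable logistic function (avoids overflow in math.exp).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def route_one_token(
    logits: List[float],
    bias: List[float],
    n_group: int,
    topk_group: int,
    topk: int,
    routed_scaling_factor: float,
) -> Tuple[List[float], List[int]]:
    """Hypothetical non-fused reference of steps 1-5 for one token."""
    num_experts = len(logits)
    group_size = num_experts // n_group  # assumes contiguous expert groups

    # Step 1: sigmoid activation, then add the per-expert bias.
    sig = [sigmoid(x) for x in logits]
    biased = [s + b for s, b in zip(sig, bias)]

    # Step 2: group score = sum of the two largest biased scores in each group.
    group_scores = []
    for g in range(n_group):
        members = biased[g * group_size:(g + 1) * group_size]
        group_scores.append(sum(sorted(members, reverse=True)[:2]))

    # Step 3: keep the topk_group groups with the highest group score.
    kept_groups = sorted(range(n_group), key=lambda g: group_scores[g],
                         reverse=True)[:topk_group]

    # Step 4: among experts in the kept groups, pick topk by biased score.
    candidates = [e for g in kept_groups
                  for e in range(g * group_size, (g + 1) * group_size)]
    top_experts = sorted(candidates, key=lambda e: biased[e], reverse=True)[:topk]

    # Step 5: normalize the unbiased sigmoid scores of the chosen experts.
    total = sum(sig[e] for e in top_experts)
    weights = [sig[e] / total * routed_scaling_factor for e in top_experts]
    return weights, top_experts


logits = [0.0, 0.1, 2.0, 1.9, -1.0, -2.0, 3.0, -3.0]
weights, indices = route_one_token(logits, [0.0] * 8, n_group=4, topk_group=2,
                                   topk=2, routed_scaling_factor=1.0)
# indices == [2, 3]: expert 6 has the largest logit, but its group (experts 6-7)
# scores sigmoid(3) + sigmoid(-3) = 1.0, below group 0's 0.5 + sigmoid(0.1), about 1.025.
```

The example shows why group selection matters. The single strongest expert can be excluded when its group's top-2 sum is weak. It also shows that the bias affects only which experts are selected, not the final weights.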

Parameters:
  • scores (torch.Tensor) – Router logits of shape (num_tokens, num_experts), before any activation. bfloat16 / float16 / float32.

  • bias (torch.Tensor) – Per-expert routing bias of shape (num_experts,), same dtype as scores. Added to the sigmoid-activated scores before grouping.

  • n_group (int) – Number of expert groups. Must satisfy n_group <= 32 and num_experts % n_group == 0. Typical value is 8 for DeepSeek-V3 with 256 experts (32 experts per group).

  • topk_group (int) – Number of top groups to select. Must satisfy topk_group <= n_group and topk_group * (num_experts / n_group) >= topk, so that the selected groups contain at least topk experts. Typical value is 4.

  • topk (int) –

    Number of top experts to select per token. Must be <= num_experts. The argument check caps topk at 32, but both branches of the kernel additionally require topk <= 8, so 8 is the effective limit. Typical value is 8.

    Further per-branch constraints:

    • When n_group > 1: num_experts / n_group <= 32 and (num_experts / n_group) * topk_group <= 128.

    • When n_group == 1: num_experts <= 384.

  • routed_scaling_factor (float) – Scaling factor applied to the normalized expert weights (see step 5 in the algorithm summary above).

  • topk_values (torch.Tensor) – Pre-allocated output tensor of shape (num_tokens, topk). Must have the same dtype as scores (bfloat16 / float16 / float32); the normalized expert weights are written here in place.

  • topk_indices (torch.Tensor) – Pre-allocated output tensor of shape (num_tokens, topk). Must be int32. The selected expert indices are written here in place.

  • launch_with_pdl (bool) – Whether to launch the kernel with Programmatic Dependent Launch. Defaults to True.

  • routing_replay_out (Optional[torch.Tensor]) – Pre-allocated int16 tensor that receives the selected expert IDs. Its shape must satisfy shape[0] >= num_tokens and shape[1] == topk. The >= on shape[0] is intentional: under CUDA graphs, one buffer sized for the maximum batch can be reused across steps with smaller num_tokens, because the kernel writes only rows [0, num_tokens). When None (the default), the kernel skips this write, so it adds no overhead.
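The parameter constraints above can be checked on the host before launching the kernel. The following is a minimal sketch, and `check_routing_args` is a hypothetical helper that is not part of FlashInfer. It takes plain integers, plus an optional `(rows, cols)` shape for routing_replay_out, and raises `ValueError` on the first violated rule. It also checks that the selected groups together hold at least topk experts, because step 4 cannot otherwise pick topk experts. Dtype requirements (int32 indices, int16 replay buffer) are not covered.

```python
from typing import Optional, Tuple


def check_routing_args(
    num_tokens: int,
    num_experts: int,
    n_group: int,
    topk_group: int,
    topk: int,
    replay_shape: Optional[Tuple[int, int]] = None,
) -> None:
    """Hypothetical host-side check; raises ValueError on the first violation."""

    def require(ok: bool, message: str) -> None:
        if not ok:
            raise ValueError(message)

    require(1 <= n_group <= 32, "n_group must be in [1, 32]")
    require(num_experts % n_group == 0, "num_experts must be divisible by n_group")
    experts_per_group = num_experts // n_group

    require(topk_group <= n_group, "topk_group must be <= n_group")
    # Step 4 can only pick topk experts if the kept groups hold that many.
    require(topk_group * experts_per_group >= topk,
            "selected groups must contain at least topk experts")

    require(topk <= num_experts, "topk must be <= num_experts")
    # The argument check allows topk <= 32, but both kernel branches need topk <= 8.
    require(topk <= 8, "topk must be <= 8")

    if n_group > 1:
        require(experts_per_group <= 32, "num_experts / n_group must be <= 32")
        require(experts_per_group * topk_group <= 128,
                "(num_experts / n_group) * topk_group must be <= 128")
    else:
        require(num_experts <= 384, "num_experts must be <= 384 when n_group == 1")

    if replay_shape is not None:
        # Extra rows are allowed so one max-batch buffer can be reused (CUDA graphs).
        require(replay_shape[0] >= num_tokens,
                "routing_replay_out.shape[0] must be >= num_tokens")
        require(replay_shape[1] == topk,
                "routing_replay_out.shape[1] must equal topk")
```

DeepSeek-V3's configuration passes every check: 256 experts, n_group=8, topk_group=4, and topk=8. Its 32 experts per group times 4 selected groups equals 128, which sits exactly at the grouped-branch limit. A replay buffer of shape (4096, 8) is also accepted for any num_tokens up to 4096.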

Returns:

Results are written in place to topk_values and topk_indices (and optionally routing_replay_out).

Return type:

None

Notes

The kernel uses float32 internally for numerical precision regardless of the input dtype. Supported on Ada (SM89), Hopper (SM90), and Blackwell (SM100/SM103/SM120/SM121). In the underlying CUDA kernel name NoAuxTc, the NoAux prefix indicates the absence of auxiliary load-balancing losses and the Tc suffix indicates Tensor-Core utilization.
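A host-side guard can compare the device's compute capability against the architectures listed above. This sketch assumes the capability arrives as a `(major, minor)` tuple, such as the one returned by PyTorch's `torch.cuda.get_device_capability()`. `SUPPORTED_SM` and `is_supported_capability` are hypothetical names, and the set only mirrors this note, so it may lag behind the library's actual support matrix.

```python
from typing import Sequence

# (major, minor) compute capabilities listed as supported in the note above:
# Ada SM89, Hopper SM90, Blackwell SM100/SM103/SM120/SM121.
SUPPORTED_SM = {(8, 9), (9, 0), (10, 0), (10, 3), (12, 0), (12, 1)}


def is_supported_capability(capability: Sequence[int]) -> bool:
    """Return True if (major, minor) matches one of the listed architectures."""
    return tuple(capability) in SUPPORTED_SM
```

For example, `is_supported_capability(torch.cuda.get_device_capability())` returns False on an Ampere A100, whose capability is (8, 0).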