flashinfer.fused_moe.b12x_fused_moe

flashinfer.fused_moe.b12x_fused_moe(x: Tensor, w1_weight: Tensor, w1_weight_sf: Tensor, w2_weight: Tensor, w2_weight_sf: Tensor, token_selected_experts: Tensor, token_final_scales: Tensor, num_experts: int, top_k: int, *, w1_alpha: Tensor, w2_alpha: Tensor, fc2_input_scale: Tensor | None = None, num_local_experts: int | None = None, output: Tensor | None = None, output_dtype: dtype = torch.bfloat16, activation: str = 'silu', activation_precision: str = 'fp4', quant_mode: str | None = None, source_format: str = 'modelopt') → Tensor

Run fused MoE on SM120/SM121 using b12x CuTe-DSL kernels.

The kernel takes bfloat16 input and runs routing, FC1, the activation, FC2, and the final scatter through a single selected backend. The backend (micro for decode, static, or dynamic) is chosen automatically from the routed row count.
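The backend choice can be pictured with the small sketch below. It assumes that the "routed row count" means one row per (token, expert) pair, i.e. num_tokens * top_k, and that the three backends cover increasing row counts in the listed order. Both helpers and both thresholds are hypothetical placeholders supplied by the caller. They are not the values or functions that b12x_fused_moe uses internally.

```python
def routed_rows(num_tokens: int, top_k: int) -> int:
    # ASSUMPTION: each token is dispatched to top_k experts, so the grouped
    # GEMMs see one row per (token, expert) pair.
    return num_tokens * top_k


def pick_backend(rows: int, micro_max_rows: int, static_max_rows: int) -> str:
    """Hypothetical dispatch rule; thresholds are caller-supplied placeholders."""
    if micro_max_rows > static_max_rows:
        raise ValueError("micro_max_rows must not exceed static_max_rows")
    if rows <= micro_max_rows:
        return "micro"    # small, decode-sized batches
    if rows <= static_max_rows:
        return "static"
    return "dynamic"      # largest routed row counts


# Example: a decode step with 4 tokens and top_k=8 yields 32 routed rows.
print(pick_backend(routed_rows(4, 8), micro_max_rows=64, static_max_rows=4096))
```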

Parameters:
  • x (torch.Tensor) – Input activations of shape [num_tokens, hidden_size], bfloat16.

  • w1_weight (torch.Tensor) – FC1 weights, packed FP4 (two values per byte, hence hidden_size // 2). Gated (SiLU) layout: [E, 2 * intermediate_size, hidden_size // 2]; non-gated (ReLU2) layout: [E, intermediate_size, hidden_size // 2].

  • w1_weight_sf (torch.Tensor) – Scale factors for w1_weight.

  • w2_weight (torch.Tensor) – FC2 weights, packed FP4, of shape [E, hidden_size, intermediate_size // 2].

  • w2_weight_sf (torch.Tensor) – Scale factors for w2_weight.

  • token_selected_experts (torch.Tensor) – Expert assignments of shape [num_tokens, top_k].

  • token_final_scales (torch.Tensor) – Routing weights of shape [num_tokens, top_k].

  • num_experts (int) – Total number of experts.

  • top_k (int) – Number of experts routed to per token.

  • w1_alpha (torch.Tensor) – Per-expert global scale for FC1.

  • w2_alpha (torch.Tensor) – Per-expert global scale for FC2.

  • fc2_input_scale (Optional[torch.Tensor]) – Global scale for FC2 input quantization. Required for quant_mode="nvfp4"; accepted but ignored for quant_mode="w4a16".

  • num_local_experts (Optional[int]) – Number of experts held locally when using expert parallelism. Defaults to num_experts.

  • output (Optional[torch.Tensor]) – Pre-allocated output buffer of shape [num_tokens, hidden_size], bfloat16.

  • output_dtype (torch.dtype) – Output data type. Only torch.bfloat16 is currently supported.

  • activation (str) – Activation function: "silu" (gated SwiGLU) or "relu2" (non-gated squared ReLU, as used by Nemotron-Super). Defaults to "silu".

  • activation_precision (str) – Backward-compatible alias for quant_mode. "fp4" selects quant_mode="nvfp4"; "bf16" selects quant_mode="w4a16".

  • quant_mode (Optional[str]) – Quantization mode: "nvfp4" (also accepted as "w4a4") or "w4a16". When set, it selects the backend and the internal workspace family.

  • source_format (str) – Source weight format for quant_mode="w4a16": "modelopt" or "compressed_tensors". Defaults to "modelopt".
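The following sketch collects the shape and mode rules listed above into plain Python for checking arguments before a call. The helper names (expected_weight_shapes, resolve_quant_mode, check_fc2_input_scale) are hypothetical and not part of FlashInfer. Two details are assumptions: explicit quant_mode taking precedence over activation_precision, and "w4a4" being normalized to "nvfp4". Scale-factor shapes are not covered because this page does not specify them.

```python
def expected_weight_shapes(num_experts, hidden_size, intermediate_size, activation="silu"):
    """Packed-FP4 weight shapes as documented (two 4-bit values per byte)."""
    if activation not in ("silu", "relu2"):
        raise ValueError(f"unsupported activation: {activation!r}")
    if hidden_size % 2 or intermediate_size % 2:
        raise ValueError("FP4 packing needs even hidden_size and intermediate_size")
    # Gated SiLU stores gate and up projections together, doubling FC1 rows.
    fc1_rows = 2 * intermediate_size if activation == "silu" else intermediate_size
    return {
        "w1_weight": (num_experts, fc1_rows, hidden_size // 2),
        "w2_weight": (num_experts, hidden_size, intermediate_size // 2),
    }


_PRECISION_TO_MODE = {"fp4": "nvfp4", "bf16": "w4a16"}


def resolve_quant_mode(quant_mode=None, activation_precision="fp4"):
    # ASSUMPTION: an explicit quant_mode wins over the legacy alias.
    if quant_mode is not None:
        mode = "nvfp4" if quant_mode == "w4a4" else quant_mode
    else:
        if activation_precision not in _PRECISION_TO_MODE:
            raise ValueError(f"unknown activation_precision: {activation_precision!r}")
        mode = _PRECISION_TO_MODE[activation_precision]
    if mode not in ("nvfp4", "w4a16"):
        raise ValueError(f"unknown quant_mode: {mode!r}")
    return mode


def check_fc2_input_scale(mode, fc2_input_scale):
    # Documented: required for nvfp4, accepted but ignored for w4a16.
    if mode == "nvfp4" and fc2_input_scale is None:
        raise ValueError("fc2_input_scale is required for quant_mode='nvfp4'")


print(expected_weight_shapes(8, 4096, 1024))
print(resolve_quant_mode(activation_precision="bf16"))
```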

Returns:

Output tensor of shape [num_tokens, hidden_size].

Return type:

torch.Tensor
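For orientation, the sketch below is an unquantized, pure-Python reference of what the returned tensor represents: route each token, apply FC1, the activation, and FC2 per selected expert, then scatter-add the results weighted by token_final_scales. The helper reference_moe is hypothetical. It ignores FP4 packing, the scale factors, w1_alpha/w2_alpha, and fc2_input_scale, so the real kernel's numerics will differ. The assumption that the first half of the gated FC1 output is the gate and the second half is the up projection is also unconfirmed here.

```python
import math


def silu(v: float) -> float:
    # Numerically stable SiLU: v * sigmoid(v).
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    ev = math.exp(v)
    return v * ev / (1.0 + ev)


def matvec(w, x):
    # w: [out_features][in_features], x: [in_features]
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def reference_moe(x, w1, w2, selected, scales, activation="silu"):
    """x: [T][H]; w1[e]: [2I][H] (silu) or [I][H] (relu2); w2[e]: [H][I];
    selected, scales: [T][top_k]. Returns [T][H]."""
    out = [[0.0] * len(x[0]) for _ in x]
    for t, row in enumerate(x):
        for e, s in zip(selected[t], scales[t]):
            h = matvec(w1[e], row)  # FC1
            if activation == "silu":
                half = len(h) // 2
                # ASSUMPTION: first half = gate, second half = up projection.
                a = [silu(g) * u for g, u in zip(h[:half], h[half:])]
            else:
                a = [max(v, 0.0) ** 2 for v in h]  # relu2: squared ReLU
            y = matvec(w2[e], a)  # FC2
            for j, v in enumerate(y):
                out[t][j] += s * v  # weighted scatter-add into the token's row
    return out
```

The output has one row per input token and hidden_size columns, matching the [num_tokens, hidden_size] shape documented above.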