flashinfer.fused_moe.b12x_fused_moe
- flashinfer.fused_moe.b12x_fused_moe(x: Tensor, w1_weight: Tensor, w1_weight_sf: Tensor, w2_weight: Tensor, w2_weight_sf: Tensor, token_selected_experts: Tensor, token_final_scales: Tensor, num_experts: int, top_k: int, *, w1_alpha: Tensor, w2_alpha: Tensor, fc2_input_scale: Tensor | None = None, num_local_experts: int | None = None, output: Tensor | None = None, output_dtype: dtype = torch.bfloat16, activation: str = 'silu', activation_precision: str = 'fp4', quant_mode: str | None = None, source_format: str = 'modelopt') → Tensor
Run fused MoE on SM120/SM121 using b12x CuTe-DSL kernels.
The kernel takes bfloat16 input and runs routing, FC1, the activation, FC2, and the final scatter through the selected backend. The backend (micro for decode, static, or dynamic) is selected automatically based on the routed row count.
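To make the pipeline concrete, the sketch below is an unquantized, pure-Python reference of what the fused call computes for each token: FC1, the activation, FC2, and a weighted scatter-add of every selected expert's output into the token's output row. It is a sketch under stated assumptions, not the kernel: FP4 packing, the scale factors (w1_weight_sf, w2_weight_sf), and the alpha scales are ignored; weights are given unpacked as [rows][in_features]; `reference_moe` is a hypothetical name; and the gate/up ordering inside the gated FC1 output (first half treated as the gate) is an assumption this page does not specify.

```python
import math


def silu(v):
    # SiLU(v) = v * sigmoid(v)
    return v / (1.0 + math.exp(-v))


def matvec(w, v):
    # w: list of rows [out_features][in_features]; v: list of in_features
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def reference_moe(x, w1, w2, selected, scales, activation="silu"):
    """Hypothetical unquantized reference for the fused MoE pipeline.

    x:        [num_tokens][hidden_size]
    w1[e]:    [2 * intermediate_size][hidden_size] for "silu" (gated),
              [intermediate_size][hidden_size] for "relu2" (non-gated)
    w2[e]:    [hidden_size][intermediate_size]
    selected: [num_tokens][top_k] expert ids
    scales:   [num_tokens][top_k] routing weights
    """
    hidden = len(x[0])
    out = [[0.0] * hidden for _ in x]
    for t, row in enumerate(x):
        for e, s in zip(selected[t], scales[t]):
            h = matvec(w1[e], row)  # FC1
            if activation == "silu":
                inter = len(h) // 2
                # Assumption: first half is the gate, second half is "up".
                a = [silu(g) * u for g, u in zip(h[:inter], h[inter:])]
            elif activation == "relu2":
                a = [max(v, 0.0) ** 2 for v in h]  # squared ReLU
            else:
                raise ValueError(f"unsupported activation: {activation}")
            y = matvec(w2[e], a)  # FC2
            for j in range(hidden):  # weighted scatter-add into the token's row
                out[t][j] += s * y[j]
    return out
```

A reference like this is mainly useful for checking a dequantized copy of the weights against the fused kernel's output on small shapes, with a tolerance that accounts for FP4 quantization error.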
- Parameters:
x (torch.Tensor) – Input activations of shape [num_tokens, hidden_size], bfloat16.
w1_weight (torch.Tensor) – FC1 weights, FP4 packed. Gated (SiLU) layout [E, 2 * intermediate_size, hidden_size // 2]; non-gated (ReLU2) layout [E, intermediate_size, hidden_size // 2].
w1_weight_sf (torch.Tensor) – Scale factors for w1_weight.
w2_weight (torch.Tensor) – FC2 weights, FP4 packed, of shape [E, hidden_size, intermediate_size // 2].
w2_weight_sf (torch.Tensor) – Scale factors for w2_weight.
token_selected_experts (torch.Tensor) – Expert assignments of shape [num_tokens, top_k].
token_final_scales (torch.Tensor) – Routing weights of shape [num_tokens, top_k].
num_experts (int) – Total number of experts.
top_k (int) – Number of experts routed to per token.
w1_alpha (torch.Tensor) – Per-expert global scale for FC1.
w2_alpha (torch.Tensor) – Per-expert global scale for FC2.
fc2_input_scale (Optional[torch.Tensor]) – Global scale for FC2 input quantization. Required for quant_mode="nvfp4"; accepted but ignored for quant_mode="w4a16".
num_local_experts (Optional[int]) – Number of local experts when using expert parallelism. Defaults to num_experts.
output (Optional[torch.Tensor]) – Pre-allocated output buffer of shape [num_tokens, hidden_size], bfloat16.
output_dtype (torch.dtype) – Output data type. Only torch.bfloat16 is currently supported.
activation (str) – Activation function: "silu" (gated SwiGLU) or "relu2" (non-gated, as in Nemotron-Super). Defaults to "silu".
activation_precision (str) – Backward-compatible alias for quant_mode: "fp4" selects quant_mode="nvfp4"; "bf16" selects quant_mode="w4a16". Defaults to "fp4".
quant_mode (Optional[str]) – Quantization mode: "nvfp4" / "w4a4" or "w4a16". When set, selects the backend and internal workspace family.
source_format (str) – Source weight format for quant_mode="w4a16": "modelopt" or "compressed_tensors". Defaults to "modelopt".
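The parameter rules above can be checked on the caller side before launching the kernel. The sketch below is a hypothetical pre-flight helper, not part of flashinfer: it maps the legacy activation_precision alias to a quant_mode, reports whether fc2_input_scale is required, and computes the documented packed weight shapes. It assumes an explicit quant_mode takes precedence over activation_precision and that "w4a4" is a synonym for "nvfp4"; neither is spelled out on this page. It also leaves open whether E is num_experts or num_local_experts.

```python
def resolve_quant_mode(quant_mode=None, activation_precision="fp4"):
    """Hypothetical helper: normalize quant_mode / activation_precision."""
    if quant_mode is not None:
        # Assumption: an explicit quant_mode wins over the legacy alias.
        mode = quant_mode
    else:
        alias = {"fp4": "nvfp4", "bf16": "w4a16"}
        if activation_precision not in alias:
            raise ValueError(f"unknown activation_precision: {activation_precision}")
        mode = alias[activation_precision]
    if mode == "w4a4":
        mode = "nvfp4"  # Assumption: "w4a4" is treated as "nvfp4".
    if mode not in ("nvfp4", "w4a16"):
        raise ValueError(f"unknown quant_mode: {mode}")
    return mode


def needs_fc2_input_scale(mode):
    # Documented: required for "nvfp4", accepted but ignored for "w4a16".
    return mode == "nvfp4"


def expected_weight_shapes(num_experts_in_weights, hidden_size, intermediate_size,
                           activation="silu"):
    """Packed FP4 shapes from the parameter list (E = num_experts_in_weights)."""
    if hidden_size % 2 or intermediate_size % 2:
        # Two FP4 values share one byte along the reduction dimension.
        raise ValueError("hidden_size and intermediate_size must be even")
    if activation == "silu":
        fc1_rows = 2 * intermediate_size  # gated: gate and up projections
    elif activation == "relu2":
        fc1_rows = intermediate_size      # non-gated
    else:
        raise ValueError(f"unsupported activation: {activation}")
    return {
        "w1_weight": (num_experts_in_weights, fc1_rows, hidden_size // 2),
        "w2_weight": (num_experts_in_weights, hidden_size, intermediate_size // 2),
    }
```

Comparing `tuple(w1_weight.shape)` against `expected_weight_shapes(...)["w1_weight"]` catches a gated/non-gated layout mismatch before the kernel sees it.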
- Returns:
Output tensor of shape [num_tokens, hidden_size].
- Return type:
torch.Tensor
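Because token_selected_experts and token_final_scales are inputs, the choice of experts is made before this function is called. The sketch below shows one common way to build those two arrays from router logits: softmax over experts, keep the top_k, and optionally renormalize the kept weights to sum to 1. This page does not say which routing scheme a given model uses, so treat `topk_softmax_routing` (a hypothetical name) and its renormalization default as assumptions. The resulting lists would still need converting to tensors of whatever dtype the kernel expects, which is not stated here.

```python
import math


def topk_softmax_routing(logits, top_k, renormalize=True):
    """Hypothetical helper returning (token_selected_experts, token_final_scales)
    as nested lists of shape [num_tokens][top_k]."""
    selected, scales = [], []
    for row in logits:
        m = max(row)
        exps = [math.exp(v - m) for v in row]  # numerically stable softmax
        total = sum(exps)
        probs = [e / total for e in exps]
        # Indices of the top_k largest probabilities; ties go to the lower index.
        idx = sorted(range(len(probs)), key=lambda i: (-probs[i], i))[:top_k]
        w = [probs[i] for i in idx]
        if renormalize:
            s = sum(w)
            w = [v / s for v in w]
        selected.append(idx)
        scales.append(w)
    return selected, scales
```

With renormalization the kept weights form a convex combination of the selected experts; without it they are the raw softmax probabilities, which sum to less than 1 whenever top_k is smaller than the number of experts.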