flashinfer.fused_moe.trtllm_fp4_block_scale_moe

flashinfer.fused_moe.trtllm_fp4_block_scale_moe(routing_logits: Tensor, routing_bias: Tensor | None, hidden_states: Tensor, hidden_states_scale: Tensor | None, gemm1_weights: Tensor, gemm1_weights_scale: Tensor, gemm1_bias: Tensor | None, gemm1_alpha: Tensor | None, gemm1_beta: Tensor | None, gemm1_clamp_limit: Tensor | None, gemm2_weights: Tensor, gemm2_weights_scale: Tensor, gemm2_bias: Tensor | None, output1_scale_scalar: Tensor | None, output1_scale_gate_scalar: Tensor | None, output2_scale_scalar: Tensor | None, num_experts: int, top_k: int, n_group: int | None, topk_group: int | None, intermediate_size: int, local_expert_offset: int, local_num_experts: int, routed_scaling_factor: float | None, routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: bool | None = None, activation_type: int = 3, per_token_scale: Tensor | None = None, output: Tensor | None = None, tune_max_num_tokens: int = 8192, norm_topk_prob: bool = True, routing_replay_out: Tensor | None = None) → List[Tensor]

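Fused Mixture-of-Experts (MoE) layer with FP4 block-scaled expert weights. Performs routing, the FC1 projection with its activation, the FC2 projection, and (when do_finalize is True) the weighted combination of each token's top_k expert outputs.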

Parameters:
  • routing_logits (torch.Tensor) – [seq_len, num_experts] tensor of routing logits. float32 or bfloat16.

  • routing_bias (Optional[torch.Tensor]) – [num_experts] tensor of routing bias. Same dtype as routing_logits; may be None.

  • hidden_states (torch.Tensor) – Hidden states of shape [seq_len, hidden_size // 2] (NVFP4) or [seq_len, hidden_size] (MXFP8 / bfloat16). Supports bfloat16, MXFP8, and NVFP4 (packed into uint8).

  • hidden_states_scale (Optional[torch.Tensor]) – Block scales for MXFP8 / NVFP4 hidden states, of shape [seq_len, hidden_size // (32 if mxfp8 else 16)], dtype float8. None for bfloat16 hidden states.

  • gemm1_weights (torch.Tensor) – [num_experts, 2 * intermediate_size, hidden_size // 2] packed FP4 FC1 weights, dtype uint8.

  • gemm1_weights_scale (torch.Tensor) – [num_experts, 2 * intermediate_size, hidden_size // (32 if mxfp4 else 16)] FC1 weight block scales, dtype float8.

  • gemm1_bias (Optional[torch.Tensor]) – [num_experts, 2 * intermediate_size] FC1 bias, float32.

  • gemm1_alpha (Optional[torch.Tensor]) – [num_experts] swiglu alpha, float32.

  • gemm1_beta (Optional[torch.Tensor]) – [num_experts] swiglu beta, float32.

  • gemm1_clamp_limit (Optional[torch.Tensor]) – [num_experts] swiglu clamp limit, float32.

  • gemm2_weights (torch.Tensor) – [num_experts, hidden_size, intermediate_size // 2] packed FP4 FC2 weights (two FP4 values per byte), dtype uint8.

  • gemm2_weights_scale (torch.Tensor) – [num_experts, hidden_size, intermediate_size // (32 if mxfp4 else 16)] FC2 weight block scales, dtype float8.

  • gemm2_bias (Optional[torch.Tensor]) – [num_experts, hidden_size] FC2 bias, float32.

  • output1_scale_scalar (Optional[torch.Tensor]) – [local_num_experts] scaling factors for the first-layer activation output.

  • output1_scale_gate_scalar (Optional[torch.Tensor]) – [local_num_experts] scaling factors for the first-layer gate output.

  • output2_scale_scalar (Optional[torch.Tensor]) – [local_num_experts] scaling factors for the second-layer output.

  • num_experts (int) – Total number of experts.

  • top_k (int) – Number of experts to route to per token.

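  • n_group (Optional[int]) – Number of expert groups used by grouped routing (DeepSeekV3); may be None.

  • topk_group (Optional[int]) – Number of expert groups kept per token before the final top_k expert selection (DeepSeekV3); may be None.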

  • intermediate_size (int) – Size of the intermediate layer.

  • local_expert_offset (int) – Offset of local experts in the global expert space.

  • local_num_experts (int) – Number of experts handled by this device.

  • routed_scaling_factor (Optional[float]) – Scaling factor applied to the final routing weights (used, for example, by DeepSeekV3 routing); may be None.

  • routing_method_type (int) –

    Routing method (default 0). Selects the routing-kernel pipeline; matches flashinfer.tllm_enums.RoutingMethodType.

    • 0 Default — Softmax → TopK.

    • 1 Renormalize — TopK → Softmax.

    • 2 DeepSeekV3 — Sigmoid → RoutingBiasAdd → Top-2 in group → Top-topk_group groups → Top-top_k experts from the selected groups.

    • 3 Llama4 — Top-1 → Sigmoid.

    • 4 RenormalizeNaive — Softmax → TopK → Renormalize (Qwen3 style).

    • 5 TopK — TopK only (no softmax/sigmoid).

    • 6 SigmoidRenorm — Sigmoid → TopK → Renormalize (divide by the sum of the top-K weights).

    • 7 MiniMax2 — Sigmoid + Bias → TopK → ScaledSumNormalize (routeScale = 1.0, epsilon = 1e-20).

    • 8 Sigmoid — Sigmoid → TopK (no renormalization).

    • 9 Unspecified — reserved.

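  • do_finalize (bool) – If True (default), combine each token's top_k expert outputs, weighted by the routing weights, into the final [seq_len, hidden_size] output. If False, skip this step and return the intermediate tensors listed under Returns.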

  • enable_pdl (Optional[bool]) – Whether to enable Programmatic Dependent Launch.

  • activation_type (int) – Activation applied between FC1 and FC2 (default 3, Swiglu). Supported values: 3 Swiglu; 4 Geglu; 6 Relu2; 7 Identity.

  • per_token_scale (Optional[torch.Tensor]) – [seq_len] per-token scaling factors, float32.

  • output (Optional[torch.Tensor]) – Optional preallocated [seq_len, hidden_size] tensor that receives the result instead of a newly allocated one.

  • tune_max_num_tokens (int) – Maximum number of tokens for autotuning (default 8192).

  • norm_topk_prob (bool) – Whether to normalize the top-k probabilities (default True).

  • routing_replay_out (Optional[torch.Tensor]) – Optional int16 tensor of shape (num_tokens_or_larger, top_k) used to capture the selected expert IDs during routing. Column order matches topk_indices. When None (default) the kernel skips the write entirely. The buffer may be larger than num_tokens for CUDA-graph pre-allocation; only rows [0, num_tokens) are written.
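The packing rules in the parameter list determine every input shape. The helper below is a minimal pure-Python sketch that derives the expected shapes from seq_len, hidden_size, intermediate_size and num_experts. The function expected_shapes, the MoEShapes class, and the format strings ("nvfp4", "mxfp8", "mxfp4", "bf16") are hypothetical names, not part of FlashInfer; the sketch assumes FC2 weights are packed two FP4 values per uint8 along intermediate_size, just as FC1 weights are packed along hidden_size.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

Shape = Tuple[int, ...]


@dataclass(frozen=True)
class MoEShapes:
    """Hypothetical container for the expected input shapes."""
    hidden_states: Shape
    hidden_states_scale: Optional[Shape]
    gemm1_weights: Shape
    gemm1_weights_scale: Shape
    gemm2_weights: Shape
    gemm2_weights_scale: Shape


# Elements covered by one block scale: NVFP4 uses 16, the MX formats use 32.
_BLOCK = {"nvfp4": 16, "mxfp8": 32, "mxfp4": 32, "bf16": None}


def expected_shapes(seq_len: int, hidden_size: int, intermediate_size: int,
                    num_experts: int, act_format: str = "nvfp4",
                    weight_format: str = "nvfp4") -> MoEShapes:
    if weight_format not in ("nvfp4", "mxfp4"):
        raise ValueError("weights must be FP4 (nvfp4 or mxfp4)")
    act_block = _BLOCK[act_format]
    w_block = _BLOCK[weight_format]
    # NVFP4 activations pack two 4-bit values per uint8, halving the last dim.
    hs_cols = hidden_size // 2 if act_format == "nvfp4" else hidden_size
    # bfloat16 activations carry no block scales.
    hs_scale = None if act_block is None else (seq_len, hidden_size // act_block)
    return MoEShapes(
        hidden_states=(seq_len, hs_cols),
        hidden_states_scale=hs_scale,
        # FC1 has 2 * intermediate_size output rows (the gated activation's two halves).
        gemm1_weights=(num_experts, 2 * intermediate_size, hidden_size // 2),
        gemm1_weights_scale=(num_experts, 2 * intermediate_size, hidden_size // w_block),
        # Assumption: FC2 weights are packed along intermediate_size the same way.
        gemm2_weights=(num_experts, hidden_size, intermediate_size // 2),
        gemm2_weights_scale=(num_experts, hidden_size, intermediate_size // w_block),
    )
```

Comparing these tuples against each tensor's .shape before the call turns a layout mistake into an early, readable error rather than a kernel-side failure.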
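To make the routing_method_type pipelines concrete, the following is a minimal single-token reference in plain Python for methods 0, 1, 3, 4, 6 and 8, written only from the step lists for routing_method_type. It is not the FlashInfer kernel: route_token and its helpers are hypothetical names, ties are broken by the lower expert index (an assumption), and kernel details such as accumulation precision are ignored.

```python
import math
from typing import List, Sequence, Tuple


def _softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [v / s for v in e]


def _sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def _topk(scores: Sequence[float], k: int) -> List[int]:
    # Indices of the k largest scores, highest first; ties go to the lower index.
    return sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]


def route_token(logits: Sequence[float], top_k: int,
                method: int) -> Tuple[List[int], List[float]]:
    """Hypothetical reference routing for one token: (expert_ids, weights)."""
    if method == 0:                      # Default: Softmax -> TopK
        probs = _softmax(logits)
        ids = _topk(probs, top_k)
        return ids, [probs[i] for i in ids]
    if method == 1:                      # Renormalize: TopK -> Softmax
        ids = _topk(logits, top_k)
        return ids, _softmax([logits[i] for i in ids])
    if method == 3:                      # Llama4: Top-1 -> Sigmoid (call with top_k=1)
        ids = _topk(logits, top_k)
        return ids, [_sigmoid(logits[i]) for i in ids]
    if method == 4:                      # RenormalizeNaive: Softmax -> TopK -> Renormalize
        probs = _softmax(logits)
        ids = _topk(probs, top_k)
        total = sum(probs[i] for i in ids)
        return ids, [probs[i] / total for i in ids]
    if method in (6, 8):                 # SigmoidRenorm / Sigmoid
        scores = [_sigmoid(x) for x in logits]
        ids = _topk(scores, top_k)
        weights = [scores[i] for i in ids]
        if method == 6:                  # divide by the sum of the top-K weights
            total = sum(weights)
            weights = [w / total for w in weights]
        return ids, weights
    raise NotImplementedError(f"routing method {method} is not sketched here")
```

Methods 1 and 4 choose the same experts and, in exact arithmetic, the same weights: softmax is monotonic, so TopK over logits or over probabilities selects the same indices, and a softmax over the selected logits equals the full softmax renormalized over those indices. The two methods differ only in the order of the listed steps.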
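The DeepSeekV3 pipeline (routing_method_type 2) combines routing_bias, n_group, topk_group and routed_scaling_factor. The sketch below follows its listed steps for a single token in plain Python. Two points are assumptions not stated in this reference: experts are split into n_group contiguous groups of equal size, and the final weights are the un-biased sigmoid scores of the chosen experts, renormalized and multiplied by routed_scaling_factor (the convention of the DeepSeek-V3 reference model). The function name deepseek_v3_route is hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def deepseek_v3_route(logits: Sequence[float], bias: Sequence[float], top_k: int,
                      n_group: int, topk_group: int,
                      routed_scaling_factor: float) -> Tuple[List[int], List[float]]:
    num_experts = len(logits)
    assert num_experts % n_group == 0, "assumed: equal-sized contiguous groups"
    group_size = num_experts // n_group
    scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]    # Sigmoid
    biased = [s + b for s, b in zip(scores, bias)]           # RoutingBiasAdd
    # Score each group by the sum of its two largest biased scores (Top-2 in group).
    group_scores = []
    for g in range(n_group):
        members = sorted(biased[g * group_size:(g + 1) * group_size], reverse=True)
        group_scores.append(sum(members[:2]))
    # Keep the topk_group best groups.
    kept = sorted(range(n_group), key=lambda g: (-group_scores[g], g))[:topk_group]
    candidates = [e for g in kept for e in range(g * group_size, (g + 1) * group_size)]
    # Top-top_k experts among the surviving groups, ranked by biased score.
    ids = sorted(candidates, key=lambda e: (-biased[e], e))[:top_k]
    # Assumption: weights use un-biased scores, renormalized, then scaled.
    total = sum(scores[e] for e in ids)
    weights = [routed_scaling_factor * scores[e] / total for e in ids]
    return ids, weights
```

The bias only steers which groups and experts are selected; under the stated assumption it does not enter the returned weights, so a large bias on an expert raises its chance of selection without inflating its contribution.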
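The activation_type values and the gemm1_alpha, gemm1_beta and gemm1_clamp_limit parameters describe the element-wise step between FC1 and FC2. The sketch below shows one plausible per-element reading for a gate value g and a linear value x taken from the two intermediate_size halves of the FC1 output. Which half is the gate, the clamped SwiGLU form min(g, L) · sigmoid(alpha · min(g, L)) · (clamp(x, -L, L) + beta), and the erf-based GELU are all assumptions; with alpha, beta and clamp_limit left as None the SwiGLU reduces to the standard silu(g) · x. All function names are hypothetical.

```python
import math
from typing import Optional


def _sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def swiglu(gate: float, linear: float, alpha: Optional[float] = None,
           beta: Optional[float] = None, clamp_limit: Optional[float] = None) -> float:
    """Assumed SwiGLU (activation_type 3) with optional alpha/beta/clamp."""
    a = 1.0 if alpha is None else alpha
    b = 0.0 if beta is None else beta
    if clamp_limit is not None:
        gate = min(gate, clamp_limit)                            # upper clamp on gate
        linear = max(-clamp_limit, min(linear, clamp_limit))     # symmetric clamp
    return gate * _sigmoid(a * gate) * (linear + b)


def geglu(gate: float, linear: float) -> float:
    """Assumed GeGLU (activation_type 4) using the exact erf-based GELU."""
    return 0.5 * gate * (1.0 + math.erf(gate / math.sqrt(2.0))) * linear


def relu2(x: float) -> float:
    """Squared ReLU (activation_type 6)."""
    return max(x, 0.0) ** 2


def identity(x: float) -> float:
    """No activation (activation_type 7)."""
    return x
```

In this reading, per-expert alpha, beta and clamp_limit come from index e of the [num_experts] tensors for the expert that processes the row.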

Returns:

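[output], a [seq_len, hidden_size] tensor, when do_finalize is True; otherwise [gemm2_output, expert_weights, expanded_idx_to_permuted_idx], that is, the un-finalized FC2 outputs, the routing weights, and the mapping from expanded (token, k) indices to permuted rows.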

Return type:

List[torch.Tensor]
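When do_finalize is False, the caller receives the un-finalized tensors and must combine them. The sketch below shows that combine step as a routing-weighted sum of each token's top_k expert rows, using plain Python lists in place of tensors. It assumes expanded indices are laid out as token * top_k + k, that expert_weights has shape [seq_len, top_k], and that a negative entry in expanded_idx_to_permuted_idx marks an expert not handled on this device; none of these layouts are confirmed by this reference, and finalize is a hypothetical name.

```python
from typing import List, Sequence


def finalize(gemm2_output: Sequence[Sequence[float]],
             expert_weights: Sequence[Sequence[float]],
             expanded_idx_to_permuted_idx: Sequence[int],
             top_k: int) -> List[List[float]]:
    seq_len = len(expert_weights)
    hidden = len(gemm2_output[0])
    out = [[0.0] * hidden for _ in range(seq_len)]
    for t in range(seq_len):
        for k in range(top_k):
            # Assumed layout: expanded index = token * top_k + k.
            row = expanded_idx_to_permuted_idx[t * top_k + k]
            if row < 0:  # assumed marker for an expert not on this device
                continue
            w = expert_weights[t][k]
            for h in range(hidden):
                out[t][h] += w * gemm2_output[row][h]
    return out
```

For example, with top_k = 2 and permuted rows [[1, 0], [0, 1], [2, 2], [3, 3]], mapping [2, 0, 1, 3] and weights [[0.5, 0.5], [0.25, 0.75]], token 0 combines rows 2 and 0 and token 1 combines rows 1 and 3.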