flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe
- flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(topk_ids: Tensor | Tuple[Tensor, Tensor], routing_bias: Tensor | None, hidden_states: Tensor, hidden_states_scale: Tensor | None, gemm1_weights: Tensor, gemm1_weights_scale: Tensor, gemm1_bias: Tensor | None, gemm1_alpha: Tensor | None, gemm1_beta: Tensor | None, gemm1_clamp_limit: Tensor | None, gemm2_weights: Tensor, gemm2_weights_scale: Tensor, gemm2_bias: Tensor | None, output1_scale_scalar: Tensor | None, output1_scale_gate_scalar: Tensor | None, output2_scale_scalar: Tensor | None, num_experts: int, top_k: int, n_group: int | None, topk_group: int | None, intermediate_size: int, local_expert_offset: int, local_num_experts: int, routed_scaling_factor: float | None, routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: bool | None = None, activation_type: int = 3, per_token_scale: Tensor | None = None, output: Tensor | None = None, tune_max_num_tokens: int = 8192) → List[Tensor]
FP4 block scale MoE operation with pre-computed routing.
This function supports two pre-computed routing formats:
1. Packed format: topk_ids is a single int32 tensor with (expert_id << 16) | weight entries (high 16 bits = int16 expert id, low 16 bits = float16/bfloat16 weight, matching PackedScoreIdx in include/flashinfer/trtllm/fused_moe/RoutingKernel.h).
2. Unpacked format: topk_ids is a tuple (topk_ids, topk_weights).
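The packed layout can be illustrated bit by bit in pure Python. This is a sketch under stated assumptions, not FlashInfer code: pack_routing_entry and unpack_routing_entry are hypothetical helpers, the weight is assumed to be bfloat16 (the upper 16 bits of the float32 encoding, rounded to nearest even, with NaN/Inf handling omitted), and expert ids are assumed to be below 32768 so the packed value stays a non-negative int32.

```python
import struct


def float_to_bf16_bits(x: float) -> int:
    """Round a float32 value to bfloat16 and return its 16-bit pattern.

    Round-to-nearest-even on the dropped low 16 bits; NaN/Inf handling
    is omitted in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_float(b: int) -> float:
    """Expand a 16-bit bfloat16 pattern back to a Python float."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def pack_routing_entry(expert_id: int, weight: float) -> int:
    """Hypothetical helper: build one (expert_id << 16) | weight entry."""
    # Assumption: 0 <= expert_id < 2**15 keeps the result a non-negative int32.
    assert 0 <= expert_id < (1 << 15)
    return (expert_id << 16) | float_to_bf16_bits(weight)


def unpack_routing_entry(packed: int) -> tuple[int, float]:
    """Hypothetical helper: split a packed entry into (expert_id, weight)."""
    expert_id = packed >> 16             # high 16 bits: expert id
    weight = bf16_bits_to_float(packed)  # low 16 bits: bfloat16 weight
    return expert_id, weight


# One token routed to experts 5 and 12 with weights 0.75 and 0.25.
row = [pack_routing_entry(5, 0.75), pack_routing_entry(12, 0.25)]
print([hex(v) for v in row])                   # ['0x53f40', '0xc3e80']
print([unpack_routing_entry(v) for v in row])  # [(5, 0.75), (12, 0.25)]
```

For whole tensors, a plausible vectorized equivalent in PyTorch is (ids.to(torch.int32) << 16) | (weights.to(torch.bfloat16).view(torch.int16).to(torch.int32) & 0xFFFF), where the mask discards the sign extension from the int16 view; treat that expression as an untested sketch as well.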
- Parameters:
topk_ids (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – Pre-computed routing decision. Either a single int32 tensor of shape [seq_len, top_k] in packed format (expert_id << 16) | weight, or a tuple (ids, weights) where ids is int32 of shape [seq_len, top_k] (plain expert indices) and weights is bfloat16 of the same shape (routing weights).
routing_bias (Optional[torch.Tensor]) – [num_experts] routing bias. May be None.
hidden_states (torch.Tensor) – Hidden states of shape [seq_len, hidden_size // 2] (NVFP4) or [seq_len, hidden_size] (MXFP8 / bfloat16).
hidden_states_scale (Optional[torch.Tensor]) – [seq_len, hidden_size // (32 if mxfp8 else 16)] block scales of the hidden states, float8.
gemm1_weights (torch.Tensor) – [num_experts, 2 * intermediate_size, hidden_size // 2] packed FP4 FC1 weights, uint8.
gemm1_weights_scale (torch.Tensor) – [num_experts, 2 * intermediate_size, hidden_size // (32 if mxfp4 else 16)] FC1 weight block scales, float8.
gemm1_bias (Optional[torch.Tensor]) – [num_experts, 2 * intermediate_size] FC1 bias, float32.
gemm1_alpha (Optional[torch.Tensor]) – [num_experts] SwiGLU alpha, float32.
gemm1_beta (Optional[torch.Tensor]) – [num_experts] SwiGLU beta, float32.
gemm1_clamp_limit (Optional[torch.Tensor]) – [num_experts] SwiGLU clamp limit, float32.
gemm2_weights (torch.Tensor) – [num_experts, hidden_size, intermediate_size // 2] packed FP4 FC2 weights, uint8 (two FP4 values per byte, as for gemm1_weights).
gemm2_weights_scale (torch.Tensor) – [num_experts, hidden_size, intermediate_size // (32 if mxfp4 else 16)] FC2 weight block scales, float8.
gemm2_bias (Optional[torch.Tensor]) – [num_experts, hidden_size] FC2 bias, float32.
output1_scale_scalar (Optional[torch.Tensor]) – [local_num_experts] scaling factors for the first-layer activation output.
output1_scale_gate_scalar (Optional[torch.Tensor]) – [local_num_experts] scaling factors for the first-layer gate output.
output2_scale_scalar (Optional[torch.Tensor]) – [local_num_experts] scaling factors for the second-layer output.
num_experts (int) – Total number of experts.
top_k (int) – Number of experts to route to per token.
n_group (Optional[int]) – Number of expert groups.
topk_group (Optional[int]) – Number of groups to consider for top-k routing.
intermediate_size (int) – Size of the intermediate layer.
local_expert_offset (int) – Offset of local experts in the global expert space.
local_num_experts (int) – Number of experts handled by this device.
routed_scaling_factor (Optional[float]) – Scaling factor for routing.
routing_method_type (int) – Routing method (default 0). Selects the routing-kernel pipeline; matches flashinfer.tllm_enums.RoutingMethodType:
  0 Default: Softmax → TopK.
  1 Renormalize: TopK → Softmax.
  2 DeepSeekV3: Sigmoid → RoutingBiasAdd → Top-2 in group → Top-topk_group groups → Top-top_k experts from the selected groups.
  3 Llama4: Top-1 → Sigmoid.
  4 RenormalizeNaive: Softmax → TopK → Renormalize (Qwen3 style).
  5 TopK: TopK only (no softmax/sigmoid).
  6 SigmoidRenorm: Sigmoid → TopK → Renormalize (divide by the sum of the top-K weights).
  7 MiniMax2: Sigmoid + Bias → TopK → ScaledSumNormalize (routeScale = 1.0, epsilon = 1e-20).
  8 Sigmoid: Sigmoid → TopK (no renormalization).
  9 Unspecified: reserved.
do_finalize (bool) – Whether to run the finalize step that combines each token's per-expert outputs into the final output (default True). When False, the unreduced intermediate tensors are returned instead (see Returns).
enable_pdl (Optional[bool]) – Whether to enable Programmatic Dependent Launch.
activation_type (int) – Activation type (default 3, Swiglu).
per_token_scale (Optional[torch.Tensor]) – [seq_len] per-token scaling factors, float32.
output (Optional[torch.Tensor]) – Optional preallocated [seq_len, hidden_size] output tensor, written in place.
tune_max_num_tokens (int) – Maximum number of tokens for autotuning (default 8192).
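Several routing_method_type pipelines differ only in the order of the normalization and top-k steps. The pure-Python sketch below evaluates a few of them for a single token's logits to make those orderings concrete. It is a reference illustration of the formulas listed above, not the CUDA routing kernels; the function names and the tie-breaking rule (lower expert index wins) are assumptions, and methods 2, 3 and 7 are deliberately left out.

```python
import math


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def topk(scores, k):
    """Return (indices, values) of the k largest scores; ties go to the lower index."""
    order = sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]
    return order, [scores[i] for i in order]


def renormalize(vals):
    s = sum(vals)
    return [v / s for v in vals]


def route(logits, k, method):
    """Illustrative single-token routing for a subset of RoutingMethodType values."""
    if method == 0:  # Default: Softmax -> TopK
        return topk(softmax(logits), k)
    if method == 1:  # Renormalize: TopK -> Softmax
        ids, vals = topk(logits, k)
        return ids, softmax(vals)
    if method == 4:  # RenormalizeNaive: Softmax -> TopK -> Renormalize
        ids, vals = topk(softmax(logits), k)
        return ids, renormalize(vals)
    if method == 5:  # TopK only
        return topk(logits, k)
    if method == 6:  # SigmoidRenorm: Sigmoid -> TopK -> Renormalize
        ids, vals = topk([sigmoid(x) for x in logits], k)
        return ids, renormalize(vals)
    if method == 8:  # Sigmoid: Sigmoid -> TopK
        return topk([sigmoid(x) for x in logits], k)
    raise NotImplementedError(f"method {method} is not covered by this sketch")


logits = [2.0, -1.0, 1.0, 0.0]
for m in (0, 1, 4, 5, 6, 8):
    print(m, route(logits, 2, m))
```

In this sketch, methods 1 and 4 yield the same weights up to floating-point rounding: renormalizing the top-k softmax probabilities cancels the full softmax denominator, so the two differ only in which exponentials are computed. The resulting (ids, weights) pairs are the kind of pre-computed decision that topk_ids carries in either format.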
- Returns:
[output] when do_finalize is True, otherwise [gemm2_output, expert_weights, expanded_idx_to_permuted_idx].
- Return type:
List[torch.Tensor]
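When do_finalize is False, the caller receives the unreduced results and has to combine them itself. The sketch below shows one plausible finalize step in pure Python, resting on assumptions that this page does not confirm: gemm2_output rows are the per-(token, expert) FC2 outputs in permuted order, expert_weights has shape [seq_len, top_k], expanded_idx_to_permuted_idx is a flat index keyed by token * top_k + k, and a negative entry marks a slot whose expert is not local to this device. finalize_reference is a hypothetical name, and the kernel's actual layout (padding, scaling) may differ.

```python
def finalize_reference(gemm2_output, expert_weights, expanded_idx_to_permuted_idx, top_k):
    """Hypothetical reference finalize: weighted sum of each token's expert outputs.

    gemm2_output: list of rows (each a list of hidden_size floats) in permuted order.
    expert_weights: seq_len lists of top_k routing weights.
    expanded_idx_to_permuted_idx: flat list of length seq_len * top_k
        (assumed index = token * top_k + k; negative = expert not on this device).
    """
    seq_len = len(expert_weights)
    hidden_size = len(gemm2_output[0]) if gemm2_output else 0
    out = [[0.0] * hidden_size for _ in range(seq_len)]
    for t in range(seq_len):
        for k in range(top_k):
            p = expanded_idx_to_permuted_idx[t * top_k + k]
            if p < 0:
                continue  # assumed: slot routed to a non-local expert
            w = expert_weights[t][k]
            row = gemm2_output[p]
            for h in range(hidden_size):
                out[t][h] += w * row[h]
    return out


# Two tokens, top_k = 2, hidden_size = 2; rows are in a made-up permuted order.
gemm2_output = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [4.0, -4.0]]
expert_weights = [[0.75, 0.25], [0.5, 0.5]]
expanded_idx_to_permuted_idx = [0, 2, 3, -1]
print(finalize_reference(gemm2_output, expert_weights, expanded_idx_to_permuted_idx, 2))
# [[1.25, 0.5], [2.0, -2.0]]
```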