flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe
- flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe(routing_logits: Tensor, routing_bias: Tensor | None, hidden_states: Tensor, gemm1_weights: Tensor, output1_scales_scalar: Tensor, output1_scales_gate_scalar: Tensor, gemm2_weights: Tensor, output2_scales_scalar: Tensor, num_experts: int, top_k: int, n_group: int | None, topk_group: int | None, intermediate_size: int, local_expert_offset: int, local_num_experts: int, routed_scaling_factor: float | None, use_routing_scales_on_input: bool, routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: bool | None = None, tune_max_num_tokens: int = 8192, activation_type: int = 3, norm_topk_prob: bool = True, routing_replay_out: Tensor | None = None) → List[Tensor] | Tensor
FP8 per-tensor-scale MoE operation.
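In a per-tensor-scale scheme, each FP8 tensor carries one float scale, and real values are approximately the FP8 values multiplied by that scale. Because a matrix product is linear in each operand, a GEMM on FP8 inputs can be dequantized afterwards with a single scalar, which is the idea behind per-expert scalars such as output1_scales_scalar. The sketch below is a minimal pure-Python illustration, assuming the FP8 E4M3 format (largest finite value 448) with round-to-nearest-even and saturation at ±448; the helper names are hypothetical, and how FlashInfer composes its actual scale scalars (for example, whether the next layer's input scale is folded in) is not stated on this page and is not modeled here.

```python
import math
from typing import List, Tuple

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def round_to_e4m3(x: float) -> float:
    """Round to the nearest E4M3 value (ties to even), saturating at +/-448 (assumption)."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)
    e = max(math.floor(math.log2(a)), -6)  # -6 is the smallest normal exponent; below it, subnormals
    step = 2.0 ** (e - 3)                  # 3 mantissa bits -> 8 representable steps per binade
    return math.copysign(min(round(a / step) * step, E4M3_MAX), x)

def quantize_per_tensor(xs: List[float]) -> Tuple[List[float], float]:
    """Hypothetical helper: one scale for the whole tensor, mapping max |x| onto E4M3_MAX."""
    amax = max(abs(v) for v in xs)
    scale = amax / E4M3_MAX if amax > 0 else 1.0
    return [round_to_e4m3(v / scale) for v in xs], scale

a = [0.5, -1.25, 2.0, 0.1]
b = [1.0, 0.75, -0.3, 2.5]
qa, sa = quantize_per_tensor(a)
qb, sb = quantize_per_tensor(b)

exact = sum(x * y for x, y in zip(a, b))
# The product is computed on FP8 values; a single scalar (sa * sb) rescales the result.
approx = sum(x * y for x, y in zip(qa, qb)) * (sa * sb)
print(f"exact={exact:.4f} approx={approx:.4f}")
```

For these inputs the two results differ by a few percent, which is FP8 rounding error rather than a scaling error: with exact (unrounded) quantization, the product of the two scales would recover the real dot product.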
- Parameters:
routing_logits (torch.Tensor) – [seq_len, num_experts] tensor of routing logits.
routing_bias (Optional[torch.Tensor]) – [num_experts] tensor of routing bias.
hidden_states (torch.Tensor) – [seq_len, hidden_size] tensor of input hidden states.
gemm1_weights (torch.Tensor) – [num_experts, M, hidden_size] first-layer weights. M is 2 * intermediate_size for gated activations and intermediate_size for non-gated activations.
output1_scales_scalar (torch.Tensor) – [local_num_experts] first-layer output scales.
output1_scales_gate_scalar (torch.Tensor) – [local_num_experts] first-layer gate scales.
gemm2_weights (torch.Tensor) – [num_experts, hidden_size, intermediate_size] second-layer weights.
output2_scales_scalar (torch.Tensor) – [local_num_experts] second-layer output scales.
num_experts (int) – Total number of experts.
top_k (int) – Number of experts to route to per token.
n_group (Optional[int]) – Number of expert groups.
topk_group (Optional[int]) – Number of groups to consider for top-k routing.
intermediate_size (int) – Size of the intermediate layer.
local_expert_offset (int) – Offset of local experts in the global expert space.
local_num_experts (int) – Number of experts handled by this device.
routed_scaling_factor (Optional[float]) – Scaling factor for routing.
use_routing_scales_on_input (bool) – Whether to use routing scales on input.
routing_method_type (int) – Routing method (default 0). Selects the routing-kernel pipeline; matches flashinfer.tllm_enums.RoutingMethodType.
  0 Default: Softmax → TopK.
  1 Renormalize: TopK → Softmax.
  2 DeepSeekV3: Sigmoid → RoutingBiasAdd → Top-2 in group → Top-topk_group groups → Top-top_k experts from the selected groups.
  3 Llama4: Top-1 → Sigmoid.
  4 RenormalizeNaive: Softmax → TopK → Renormalize (Qwen3 style).
  5 TopK: TopK only (no softmax/sigmoid).
  6 SigmoidRenorm: Sigmoid → TopK → Renormalize (divide by the sum of the top-K weights).
  7 MiniMax2: Sigmoid + Bias → TopK → ScaledSumNormalize (routeScale = 1.0, epsilon = 1e-20).
  8 Sigmoid: Sigmoid → TopK (no renormalization).
  9 Unspecified: reserved.
do_finalize (bool) – Whether to finalize the output (default True). When False, the function returns the intermediate tensors listed under Returns instead of the final MoE output.
enable_pdl (Optional[bool]) – Whether to enable Programmatic Dependent Launch. None (default) lets the runtime auto-select on SM90+.
tune_max_num_tokens (int) – Maximum number of tokens for autotuning (default 8192).
activation_type (int) – Activation type (default 3, Swiglu). 0 Gelu; 3 Swiglu; 4 Geglu; 6 Relu2 (non-gated); 7 Identity.
norm_topk_prob (bool) – Whether to normalize the top-k probabilities (default True).
routing_replay_out (Optional[torch.Tensor]) – Optional int16 tensor of shape (num_tokens_or_larger, top_k) used to capture the selected expert IDs during routing. Column order matches topk_indices. When None (default), the kernel skips the write entirely. The buffer may be larger than num_tokens for CUDA-graph pre-allocation; only rows [0, num_tokens) are written.
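The per-token math behind several routing_method_type values can be written out in plain Python as a readable reference. This is a minimal sketch for a single token, not the fused kernel, and it rests on assumptions not stated on this page: ties go to the lower expert index, TopK (5) uses the raw logits as weights, and DeepSeekV3 (2) uses contiguous, equal-size expert groups. The DeepSeekV3 helper returns only the selected expert IDs, because this page does not specify how their final weights are formed. Llama4 (3) and MiniMax2 (7) are not sketched. All function names are hypothetical.

```python
import math
from typing import List, Tuple

def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def topk(scores: List[float], k: int) -> List[int]:
    # Indices of the k largest scores, best first; ties go to the lower index (assumption).
    return sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]

def route_one_token(logits: List[float], top_k: int, method: int) -> Tuple[List[int], List[float]]:
    """Hypothetical reference for one token: returns (expert_ids, expert_weights)."""
    if method == 0:  # Default: Softmax -> TopK
        probs = softmax(logits)
        ids = topk(probs, top_k)
        return ids, [probs[i] for i in ids]
    if method == 1:  # Renormalize: TopK -> Softmax over the selected logits only
        ids = topk(logits, top_k)
        return ids, softmax([logits[i] for i in ids])
    if method == 4:  # RenormalizeNaive: Softmax -> TopK -> Renormalize
        probs = softmax(logits)
        ids = topk(probs, top_k)
        total = sum(probs[i] for i in ids)
        return ids, [probs[i] / total for i in ids]
    if method == 5:  # TopK only; raw logits used as weights (assumption)
        ids = topk(logits, top_k)
        return ids, [logits[i] for i in ids]
    if method in (6, 8):  # SigmoidRenorm (6) / Sigmoid (8)
        scores = [sigmoid(x) for x in logits]
        ids = topk(scores, top_k)
        weights = [scores[i] for i in ids]
        if method == 6:  # divide by the sum of the top-K weights
            total = sum(weights)
            weights = [w / total for w in weights]
        return ids, weights
    raise NotImplementedError(f"routing_method_type={method} is not sketched here")

def deepseek_v3_select(logits: List[float], bias: List[float], top_k: int,
                       n_group: int, topk_group: int) -> List[int]:
    """Expert selection for method 2; assumes contiguous, equal-size groups."""
    assert len(logits) % n_group == 0
    group_size = len(logits) // n_group
    biased = [sigmoid(x) + b for x, b in zip(logits, bias)]  # Sigmoid -> RoutingBiasAdd
    group_scores = []
    for g in range(n_group):  # Top-2 in group: score each group by its two best experts
        members = biased[g * group_size:(g + 1) * group_size]
        group_scores.append(sum(sorted(members, reverse=True)[:2]))
    groups = topk(group_scores, topk_group)  # keep the best topk_group groups
    candidates = [e for g in groups for e in range(g * group_size, (g + 1) * group_size)]
    return sorted(candidates, key=lambda e: (-biased[e], e))[:top_k]  # Top-top_k experts

logits = [1.0, 3.0, 2.0, 0.5]
print(route_one_token(logits, top_k=2, method=1))
print(deepseek_v3_select([0.0, 5.0, -1.0, -2.0, 1.0, 1.0, 4.0, -3.0], [0.0] * 8,
                         top_k=2, n_group=4, topk_group=2))
```

Methods 1 and 4 yield the same weights for the same logits, since a softmax over the top-k logits equals the full softmax renormalized over those k experts; the two differ in pipeline order, not in the underlying math. In the DeepSeekV3 call, expert 6 scores higher than expert 4 but is not selected, because its group loses at the group stage.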
- Returns:
Final MoE output when do_finalize is True, otherwise [gemm2_output, expert_weights, expanded_idx_to_permuted_idx].
- Return type:
torch.Tensor or List[torch.Tensor]
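With do_finalize=False, the caller receives the three intermediate tensors and performs the final combine itself. The plain-Python sketch below shows one plausible combine under assumptions that this page does not state: expert_weights has shape [num_tokens, top_k], the expanded index t * top_k + k addresses the k-th selected expert of token t, and a negative entry in expanded_idx_to_permuted_idx marks an expert that was not computed on this device. The function name is hypothetical; verify the actual layout in the FlashInfer sources before relying on it.

```python
from typing import List

def finalize_sketch(gemm2_output: List[List[float]],
                    expert_weights: List[List[float]],
                    expanded_idx_to_permuted_idx: List[int],
                    top_k: int) -> List[List[float]]:
    """Hypothetical combine: out[t] = sum_k w[t][k] * gemm2_output[perm[t * top_k + k]]."""
    num_tokens = len(expert_weights)
    hidden_size = len(gemm2_output[0])
    out = [[0.0] * hidden_size for _ in range(num_tokens)]
    for t in range(num_tokens):
        for k in range(top_k):
            p = expanded_idx_to_permuted_idx[t * top_k + k]  # assumed expanded-index layout
            if p < 0:  # assumed marker: expert not handled on this device
                continue
            w = expert_weights[t][k]
            row = gemm2_output[p]  # rows of gemm2_output are in permuted (expert-grouped) order
            for h in range(hidden_size):
                out[t][h] += w * row[h]
    return out

# Two tokens, top_k = 2; token 1's second expert lives on another device (-1).
rows = [[1.0, 2.0], [10.0, 20.0], [100.0, 200.0]]
print(finalize_sketch(rows, [[0.5, 0.25], [1.0, 0.3]], [2, 0, 1, -1], top_k=2))
```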