flashinfer.fused_moe.trtllm_fp8_block_scale_routed_moe
flashinfer.fused_moe.trtllm_fp8_block_scale_routed_moe(topk_ids: Tensor, routing_bias: Tensor | None, hidden_states: Tensor, hidden_states_scale: Tensor, gemm1_weights: Tensor, gemm1_weights_scale: Tensor, gemm2_weights: Tensor, gemm2_weights_scale: Tensor, num_experts: int, top_k: int, n_group: int | None, topk_group: int | None, intermediate_size: int, local_expert_offset: int, local_num_experts: int, routed_scaling_factor: float | None, routing_method_type: int = 0, use_shuffled_weight: bool = False, weight_layout: int = 0, do_finalize: bool = True, enable_pdl: bool | None = None, gemm1_lora_delta: Tensor | None = None, output: Tensor | None = None, tune_max_num_tokens: int = 8192, fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, activation_type: int = 3) → List[Tensor] | Tensor
Pre-routed FP8 block-scaled MoE operation.
Like trtllm_fp8_block_scale_moe(), but consumes a pre-computed packed (expert_id, weight) tensor instead of routing logits. Use this entry point for CUDA-graph capture (it avoids the CPU-GPU sync incurred by logits processing) or for distributed MoE where routing happens elsewhere.
- Parameters:
topk_ids (torch.Tensor) – [seq_len, top_k] int32 tensor of packed expert indices and weights with format (expert_id << 16) | (weight_bf16.view(int16)).
routing_bias (Optional[torch.Tensor]) – [num_experts] tensor of routing bias (may be None).
hidden_states (torch.Tensor) – [seq_len, hidden_size] tensor of input hidden states.
hidden_states_scale (torch.Tensor) – [hidden_size // (32 if mxfp8 else 128), seq_len] block scales for the hidden states.
gemm1_weights (torch.Tensor) – [num_experts, M, hidden_size] first-layer weights, where M is 2 * intermediate_size for gated activations and intermediate_size for non-gated ones.
gemm1_weights_scale (torch.Tensor) – [num_experts, 2 * intermediate_size // (32 if mxfp8 else 128), hidden_size // (32 if mxfp8 else 128)] first-layer block scales.
gemm2_weights (torch.Tensor) – [num_experts, hidden_size, intermediate_size] second-layer weights.
gemm2_weights_scale (torch.Tensor) – [num_experts, hidden_size // (32 if mxfp8 else 128), intermediate_size // (32 if mxfp8 else 128)] second-layer block scales.
num_experts (int) – Total number of experts.
top_k (int) – Number of experts to route to per token.
n_group (Optional[int]) – Number of expert groups.
topk_group (Optional[int]) – Number of groups to consider for top-k routing.
intermediate_size (int) – Size of the intermediate layer.
local_expert_offset (int) – Offset of local experts in the global expert space.
local_num_experts (int) – Number of experts handled by this device.
routed_scaling_factor (Optional[float]) – Scaling factor for routing.
routing_method_type (int) – Routing method (default 0). Selects the routing-kernel pipeline; matches flashinfer.tllm_enums.RoutingMethodType.
  0 Default: Softmax → TopK.
  1 Renormalize: TopK → Softmax.
  2 DeepSeekV3: Sigmoid → RoutingBiasAdd → Top-2 in group → Top-topk_group groups → Top-top_k experts from the selected groups.
  3 Llama4: Top-1 → Sigmoid.
  4 RenormalizeNaive: Softmax → TopK → Renormalize (Qwen3 style).
  5 TopK: TopK only (no softmax/sigmoid).
  6 SigmoidRenorm: Sigmoid → TopK → Renormalize (divide by the sum of the top-K weights).
  7 MiniMax2: Sigmoid + Bias → TopK → ScaledSumNormalize (routeScale = 1.0, epsilon = 1e-20).
  8 Sigmoid: Sigmoid → TopK (no renormalization).
  9 Unspecified: reserved.
use_shuffled_weight (bool) – Whether to use the shuffled weight layout (default False).
weight_layout (int) – Weight layout for gemm1_weights / gemm2_weights; matches flashinfer.tllm_enums.WeightLayout. Allowed values for this function depend on fp8_quantization_type: DeepSeekFp8 accepts MajorK or BlockMajorK; MxFp8 requires MajorK. Default 0 (MajorK).
  0 MajorK: K-major, logical shape [Mn, K].
  1 MajorMn: M-major (A) / N-major (B), logical shape [K, Mn]. Not supported by this function.
  2 BlockMajorK: blocked along K, logical shape [K / blockK, Mn, blockK] (blockK is fixed at 128 B). Only valid when fp8_quantization_type is DeepSeekFp8.
do_finalize (bool) – Whether to finalize the output (default True).
enable_pdl (Optional[bool]) – Whether to enable Programmatic Dependent Launch. None (default) lets the runtime auto-select on SM90+.
gemm1_lora_delta (Optional[torch.Tensor]) – Optional MoE LoRA delta of shape [num_tokens, top_k, 2 * intermediate_size], bfloat16. When set with MxFp8, it is added to the FC1 output before the fused gated activation, and the post-activation FC1 output is appended to the returned list.
output (Optional[torch.Tensor]) – Optional in-place output tensor of shape [seq_len, hidden_size].
tune_max_num_tokens (int) – Maximum number of tokens for autotuning (default 8192).
fp8_quantization_type (Fp8QuantizationType) – FP8 quantization scheme (default Fp8QuantizationType.DeepSeekFp8).
activation_type (int) – Activation type (default 3, Swiglu).
  3 Swiglu; 4 Geglu; 6 Relu2; 7 Identity.
- Returns:
Return shape depends on do_finalize and gemm1_lora_delta; see trtllm_bf16_routed_moe() for the table.
- Return type:
torch.Tensor or List[torch.Tensor]
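The packed topk_ids format can be illustrated with a dependency-free Python sketch of the bit layout. This is not FlashInfer code: the helper names (route_renormalize, pack_entry, unpack_entry) are hypothetical, the routing step follows the Renormalize pipeline (TopK → Softmax) only as an example of routing done outside the kernel, and the packed weights are assumed to be the final per-expert routing weights. bfloat16 rounding is emulated with round-to-nearest-even and does not handle NaN.

```python
import math
import struct


def route_renormalize(logits: list[float], top_k: int) -> tuple[list[int], list[float]]:
    """Hypothetical host-side router: TopK over the logits, then softmax over the selected logits."""
    ids = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    m = max(logits[i] for i in ids)  # subtract the max for numerical stability
    exps = [math.exp(logits[i] - m) for i in ids]
    total = sum(exps)
    return ids, [e / total for e in exps]


def float_to_bf16_bits(x: float) -> int:
    """Round x to bfloat16 (round-to-nearest-even, NaN not handled) and return its 16-bit pattern."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))  # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)  # round on the 16 low bits that bfloat16 drops
    return (bits >> 16) & 0xFFFF


def bf16_bits_to_float(b: int) -> float:
    """Widen a 16-bit bfloat16 pattern back to a Python float."""
    (x,) = struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))
    return x


def pack_entry(expert_id: int, weight: float) -> int:
    """Pack one pair as (expert_id << 16) | bf16_bits, returned as a signed int32 value."""
    if not 0 <= expert_id < (1 << 16):
        raise ValueError("expert_id must fit in the upper 16 bits")
    word = (expert_id << 16) | float_to_bf16_bits(weight)
    return word - (1 << 32) if word >= (1 << 31) else word  # reinterpret as int32


def unpack_entry(packed: int) -> tuple[int, float]:
    """Inverse of pack_entry: recover (expert_id, bf16-rounded weight)."""
    word = packed & 0xFFFFFFFF
    return word >> 16, bf16_bits_to_float(word & 0xFFFF)


# One token routed to top_k = 2 of 4 experts; each row of topk_ids holds top_k packed values.
ids, weights = route_renormalize([1.0, 3.0, 2.0, 0.0], top_k=2)
row = [pack_entry(e, w) for e, w in zip(ids, weights)]
```

With PyTorch tensors, the same packing would typically be written as (ids.to(torch.int32) << 16) | (w.to(torch.bfloat16).view(torch.int16).to(torch.int32) & 0xFFFF); the mask keeps a sign-extended bf16 pattern (negative weight) from overwriting the expert-id bits.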
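The tensor and block-scale shapes listed under Parameters follow directly from the scale block size (32 for MxFp8, 128 otherwise). The pure-Python helper below, expected_shapes, is hypothetical and not part of FlashInfer; it simply restates the documented shapes, assuming a gated activation such as the default Swiglu (so M = 2 * intermediate_size) and that hidden_size and intermediate_size are multiples of the block size.

```python
def expected_shapes(seq_len: int, hidden_size: int, intermediate_size: int,
                    num_experts: int, mxfp8: bool = False) -> dict[str, tuple[int, ...]]:
    """Hypothetical helper: shapes implied by the parameter descriptions (gated activation assumed)."""
    block = 32 if mxfp8 else 128  # scale block size along each blocked dimension
    for name, dim in (("hidden_size", hidden_size), ("intermediate_size", intermediate_size)):
        if dim % block:
            raise ValueError(f"{name}={dim} is not a multiple of the block size {block}")
    m = 2 * intermediate_size  # gated activations stack the gate and up projections in GEMM1
    return {
        "hidden_states": (seq_len, hidden_size),
        "hidden_states_scale": (hidden_size // block, seq_len),
        "gemm1_weights": (num_experts, m, hidden_size),
        "gemm1_weights_scale": (num_experts, m // block, hidden_size // block),
        "gemm2_weights": (num_experts, hidden_size, intermediate_size),
        "gemm2_weights_scale": (num_experts, hidden_size // block, intermediate_size // block),
    }


# Illustrative sizes only.
shapes = expected_shapes(seq_len=4, hidden_size=7168, intermediate_size=2048, num_experts=256)
```

Comparing these tuples against each tensor's .shape before the call is one cheap way to catch a transposed scale tensor or a wrong block size early.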