flashinfer.fused_moe.trtllm_fp8_block_scale_routed_moe

flashinfer.fused_moe.trtllm_fp8_block_scale_routed_moe(topk_ids: Tensor, routing_bias: Tensor | None, hidden_states: Tensor, hidden_states_scale: Tensor, gemm1_weights: Tensor, gemm1_weights_scale: Tensor, gemm2_weights: Tensor, gemm2_weights_scale: Tensor, num_experts: int, top_k: int, n_group: int | None, topk_group: int | None, intermediate_size: int, local_expert_offset: int, local_num_experts: int, routed_scaling_factor: float | None, routing_method_type: int = 0, use_shuffled_weight: bool = False, weight_layout: int = 0, do_finalize: bool = True, enable_pdl: bool | None = None, gemm1_lora_delta: Tensor | None = None, output: Tensor | None = None, tune_max_num_tokens: int = 8192, fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, activation_type: int = 3) List[Tensor] | Tensor

Pre-routed FP8 block-scaled MoE operation.

Like trtllm_fp8_block_scale_moe(), but consumes a pre-computed tensor of packed (expert_id, weight) pairs instead of routing logits. Use this entry point for CUDA-graph capture (it avoids the CPU-GPU synchronization caused by logits processing) or for distributed MoE, where routing is computed elsewhere.
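
The packed (expert_id, weight) layout can be illustrated one entry at a time. The following is a minimal pure-Python sketch (stdlib only, no torch) that packs one pair into the int32 format documented for topk_ids and unpacks it again. pack_topk_entry and unpack_topk_entry are hypothetical helpers, not FlashInfer APIs. The sketch assumes round-to-nearest-even float32-to-bfloat16 conversion (NaN handling omitted) and keeps the weight bits as an unsigned 16-bit pattern, so a negative weight cannot overwrite the expert index.

```python
import struct
from typing import Tuple


def _f32_bits(x: float) -> int:
    """Raw IEEE-754 float32 bit pattern of x, as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def pack_topk_entry(expert_id: int, weight: float) -> int:
    """Hypothetical helper: pack (expert_id, weight) into one signed int32.

    Upper 16 bits: expert_id. Lower 16 bits: bfloat16 bit pattern of weight.
    """
    if not 0 <= expert_id < (1 << 16):
        raise ValueError("expert_id must fit in 16 bits")
    bits = _f32_bits(weight)
    # float32 -> bfloat16 with round-to-nearest-even (NaN handling omitted)
    bf16 = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) & 0xFFFF
    # bf16 is an unsigned 16-bit pattern here. With torch, an int16 view widened
    # to int32 is sign-extended, so it would need the same & 0xFFFF mask.
    packed = (expert_id << 16) | bf16
    # Reinterpret the 32-bit pattern as a signed int32, as stored in topk_ids
    return packed - (1 << 32) if packed >= (1 << 31) else packed


def unpack_topk_entry(packed: int) -> Tuple[int, float]:
    """Inverse of pack_topk_entry: returns (expert_id, weight)."""
    u = packed & 0xFFFFFFFF  # view the signed int32 as unsigned
    expert_id = u >> 16
    weight = struct.unpack("<f", struct.pack("<I", (u & 0xFFFF) << 16))[0]
    return expert_id, weight


print(pack_topk_entry(5, 0.5))                     # 343808
print(unpack_topk_entry(pack_topk_entry(5, 0.5)))  # (5, 0.5)
```

Weights that are not exactly representable in bfloat16 (for example 0.1) come back rounded. A vectorized torch version would follow the same formula, for example (ids.to(torch.int32) << 16) | (w.to(torch.bfloat16).view(torch.int16).to(torch.int32) & 0xFFFF), where the explicit mask is the same precaution as in the sketch.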

Parameters:
  • topk_ids (torch.Tensor) – [seq_len, top_k] int32 tensor in which each entry packs an expert index (upper 16 bits) and its bfloat16 routing weight (lower 16 bits) as (expert_id << 16) | (weight_bf16.view(int16)).

  • routing_bias (Optional[torch.Tensor]) – [num_experts] tensor of routing bias (may be None).

  • hidden_states (torch.Tensor) – [seq_len, hidden_size] tensor of input hidden states.

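  • hidden_states_scale (torch.Tensor) – [hidden_size // (32 if mxfp8 else 128), seq_len] block scales for the hidden states. Here and below, mxfp8 means fp8_quantization_type is MxFp8 (one scale per 32 elements); with DeepSeekFp8 there is one scale per 128 elements.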

  • gemm1_weights (torch.Tensor) – [num_experts, M, hidden_size] first-layer (FC1) weights, where M is 2 * intermediate_size for gated activations (Swiglu, Geglu) and intermediate_size for non-gated ones (Relu2, Identity).

  • gemm1_weights_scale (torch.Tensor) – [num_experts, 2*intermediate_size // (32 if mxfp8 else 128), hidden_size // (32 if mxfp8 else 128)] first-layer block scales.

  • gemm2_weights (torch.Tensor) – [num_experts, hidden_size, intermediate_size] second-layer weights.

  • gemm2_weights_scale (torch.Tensor) – [num_experts, hidden_size // (32 if mxfp8 else 128), intermediate_size // (32 if mxfp8 else 128)] second-layer block scales.

  • num_experts (int) – Total number of experts.

  • top_k (int) – Number of experts to route to per token.

  • n_group (Optional[int]) – Number of expert groups used by grouped routing such as DeepSeekV3 (may be None).

  • topk_group (Optional[int]) – Number of expert groups kept before the final top-k expert selection in grouped routing (may be None).

  • intermediate_size (int) – Intermediate (FFN hidden) dimension of each expert.

  • local_expert_offset (int) – Offset of local experts in the global expert space.

  • local_num_experts (int) – Number of experts handled by this device.

  • routed_scaling_factor (Optional[float]) – Scaling factor applied to the routing weights (may be None).

  • routing_method_type (int) –

    Routing method (default 0). Selects the routing-kernel pipeline; matches flashinfer.tllm_enums.RoutingMethodType.

    • 0 Default — Softmax → TopK.

    • 1 Renormalize — TopK → Softmax.

    • 2 DeepSeekV3 — Sigmoid → RoutingBiasAdd → Top-2 in group → Top-topk_group groups → Top-top_k experts from the selected groups.

    • 3 Llama4 — Top-1 → Sigmoid.

    • 4 RenormalizeNaive — Softmax → TopK → Renormalize (Qwen3 style).

    • 5 TopK — TopK only (no softmax/sigmoid).

    • 6 SigmoidRenorm — Sigmoid → TopK → Renormalize (divide by the sum of the top-K weights).

    • 7 MiniMax2 — Sigmoid + Bias → TopK → ScaledSumNormalize (routeScale = 1.0, epsilon = 1e-20).

    • 8 Sigmoid — Sigmoid → TopK (no renormalization).

    • 9 Unspecified — reserved.

  • use_shuffled_weight (bool) – Whether to use the shuffled weight layout (default False).

  • weight_layout (int) –

    Weight layout for gemm1_weights / gemm2_weights; matches flashinfer.tllm_enums.WeightLayout. Allowed values for this function depend on fp8_quantization_type: DeepSeekFp8 accepts MajorK or BlockMajorK; MxFp8 requires MajorK. Default 0 (MajorK).

    • 0 MajorK — K-major, logical shape [Mn, K].

    • 1 MajorMn — M-major (A) / N-major (B), logical shape [K, Mn]. Not supported by this function.

    • 2 BlockMajorK — Blocked along K, logical shape [K / blockK, Mn, blockK] (blockK is fixed at 128 bytes, i.e. 128 FP8 elements). Only valid when fp8_quantization_type is DeepSeekFp8.

  • do_finalize (bool) – Whether to run the finalize step, which combines each token's top_k expert outputs into a single [seq_len, hidden_size] result (default True).

  • enable_pdl (Optional[bool]) – Whether to enable Programmatic Dependent Launch. None (default) lets the runtime auto-select on SM90+.

  • gemm1_lora_delta (Optional[torch.Tensor]) – Optional MoE LoRA delta for FC1, a bfloat16 tensor of shape [num_tokens, top_k, 2 * intermediate_size]. With MxFp8 quantization, the delta is added to the FC1 output before the fused gated activation, and the post-activation FC1 output is appended to the returned list.

  • output (Optional[torch.Tensor]) – Optional preallocated output tensor of shape [seq_len, hidden_size]; when provided, the result is written into it.

  • tune_max_num_tokens (int) – Maximum number of tokens for autotuning (default 8192).

  • fp8_quantization_type (Fp8QuantizationType) – FP8 quantization scheme (default Fp8QuantizationType.DeepSeekFp8).

  • activation_type (int) –

    Activation type (default 3, Swiglu).

    • 3 Swiglu (gated).

    • 4 Geglu (gated).

    • 6 Relu2.

    • 7 Identity.
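
The shape rules in the parameter list can be collected into one place for checking inputs before a call. The following is a minimal sketch; expected_fp8_moe_shapes is a hypothetical helper, not part of FlashInfer. It assumes a gated activation (such as the default Swiglu, so M = 2 * intermediate_size), the default MajorK weight_layout with unshuffled weights, and dimensions that are exact multiples of the scaling block size.

```python
from typing import Dict, Tuple


def expected_fp8_moe_shapes(
    seq_len: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    num_experts: int,
    mxfp8: bool = False,
) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper: logical input shapes implied by the parameter list.

    Assumes a gated activation (M = 2 * intermediate_size), MajorK layout,
    and no weight shuffling.
    """
    block = 32 if mxfp8 else 128  # MxFp8: 32-element blocks; DeepSeekFp8: 128
    for name, dim in (("hidden_size", hidden_size), ("intermediate_size", intermediate_size)):
        if dim % block:
            raise ValueError(f"sketch assumes {name} is a multiple of {block}")
    m = 2 * intermediate_size  # gated FC1 produces two halves
    return {
        "topk_ids": (seq_len, top_k),
        "hidden_states": (seq_len, hidden_size),
        # Note the scale layout: blocks first, tokens second
        "hidden_states_scale": (hidden_size // block, seq_len),
        "gemm1_weights": (num_experts, m, hidden_size),
        "gemm1_weights_scale": (num_experts, m // block, hidden_size // block),
        "gemm2_weights": (num_experts, hidden_size, intermediate_size),
        "gemm2_weights_scale": (num_experts, hidden_size // block, intermediate_size // block),
        "output": (seq_len, hidden_size),
    }


shapes = expected_fp8_moe_shapes(seq_len=4, top_k=8, hidden_size=7168,
                                 intermediate_size=2048, num_experts=256)
print(shapes["gemm1_weights_scale"])  # (256, 32, 56)
```

In practice each entry would be compared against tuple(tensor.shape) for the corresponding argument; with BlockMajorK or shuffled weights the physical weight shapes differ, so the check would have to be adapted.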

Returns:

Return shape depends on do_finalize and gemm1_lora_delta; see trtllm_bf16_routed_moe() for the table.

Return type:

torch.Tensor or List[torch.Tensor]
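
Because this entry point takes pre-computed routing, the caller has to produce the expert ids and weights itself. The following is a pure-Python, single-token reference for the simpler pipelines listed under routing_method_type (0, 1, 4, 5, 6 and 8), written from their one-line descriptions; it is a sketch, not the FlashInfer routing kernel. route_token is a hypothetical name, tie-breaking by lower expert index is an assumption, raw logits as weights for TopK (5) is our reading of "no softmax/sigmoid", and DeepSeekV3, Llama4 and MiniMax2 are not covered.

```python
import math
from typing import List, Sequence, Tuple


def _softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def _sigmoid(xs: Sequence[float]) -> List[float]:
    out = []
    for x in xs:
        if x >= 0:
            out.append(1.0 / (1.0 + math.exp(-x)))
        else:  # avoid overflow in exp(-x) for large negative x
            e = math.exp(x)
            out.append(e / (1.0 + e))
    return out


def _topk(scores: Sequence[float], k: int) -> Tuple[List[int], List[float]]:
    # Largest k scores; ties go to the lower expert index (an assumption).
    ids = sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]
    return ids, [scores[i] for i in ids]


def _renorm(vals: List[float]) -> List[float]:
    total = sum(vals)  # divide by the sum of the selected weights
    return [v / total for v in vals]


def route_token(logits: Sequence[float], top_k: int, method: int) -> Tuple[List[int], List[float]]:
    """Hypothetical single-token reference: returns (expert_ids, weights)."""
    if not 0 < top_k <= len(logits):
        raise ValueError("top_k must be in [1, num_experts]")
    if method == 0:  # Default: Softmax -> TopK
        return _topk(_softmax(logits), top_k)
    if method == 1:  # Renormalize: TopK -> Softmax over the selected logits
        ids, vals = _topk(logits, top_k)
        return ids, _softmax(vals)
    if method == 4:  # RenormalizeNaive: Softmax -> TopK -> Renormalize
        ids, vals = _topk(_softmax(logits), top_k)
        return ids, _renorm(vals)
    if method == 5:  # TopK only: raw logits used as weights (assumption)
        return _topk(logits, top_k)
    if method == 6:  # SigmoidRenorm: Sigmoid -> TopK -> Renormalize
        ids, vals = _topk(_sigmoid(logits), top_k)
        return ids, _renorm(vals)
    if method == 8:  # Sigmoid: Sigmoid -> TopK, no renormalization
        return _topk(_sigmoid(logits), top_k)
    raise NotImplementedError(f"routing_method_type {method} is not covered by this sketch")


print(route_token([1.0, 3.0, 2.0, 0.0], top_k=2, method=1)[0])  # [1, 2]
```

In exact arithmetic Renormalize (1) and RenormalizeNaive (4) give the same weights, because a softmax over the top-k logits equals the full softmax renormalized over those k experts. Each resulting (expert_id, weight) pair would then be packed into the topk_ids format given in the parameter list.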