flashinfer.fused_moe.trtllm_bf16_moe
- flashinfer.fused_moe.trtllm_bf16_moe(routing_logits: Tensor, routing_bias: Tensor | None, hidden_states: Tensor, gemm1_weights: Tensor, gemm2_weights: Tensor, num_experts: int, top_k: int, n_group: int | None, topk_group: int | None, intermediate_size: int, local_expert_offset: int, local_num_experts: int, routed_scaling_factor: float | None = None, routing_method_type: int = 0, use_shuffled_weight: bool = True, weight_layout: int = WeightLayout.BlockMajorK, do_finalize: bool = True, enable_pdl: bool = True, tune_max_num_tokens: int = 8192, activation_type: int = 3, norm_topk_prob: bool = True, routing_replay_out: Tensor | None = None) → List[Tensor] | Tensor
BF16 MoE operation with autotuning support.
Implements a bfloat16 Mixture of Experts layer using the TensorRT-LLM backend with automatic performance tuning for optimal tile-size selection.
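For orientation, the following plain-Python sketch shows what a Mixture of Experts layer computes for a single token: pick top_k experts from the routing logits, then sum their outputs weighted by the routing weights. It contrasts the two simplest orderings named under routing_method_type ("Softmax → TopK" and "TopK → Softmax"). All helper names here (softmax, top_k_indices, route_softmax_then_topk, route_topk_then_softmax, moe_token) and the toy experts are hypothetical and written for illustration; the sketch says nothing about the kernel's precision, tie-breaking, or weight layout.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_indices(values, k):
    """Indices of the k largest entries, largest first (assumed tie rule: lower index wins)."""
    return sorted(range(len(values)), key=lambda i: (-values[i], i))[:k]


def route_softmax_then_topk(logits, k):
    """'Softmax -> TopK': weights are the full-softmax probabilities of the chosen experts."""
    probs = softmax(logits)
    chosen = top_k_indices(probs, k)
    return chosen, [probs[i] for i in chosen]


def route_topk_then_softmax(logits, k):
    """'TopK -> Softmax': softmax is taken over the k selected logits only."""
    chosen = top_k_indices(logits, k)
    return chosen, softmax([logits[i] for i in chosen])


def moe_token(x, experts, chosen, weights):
    """Weighted sum of the selected experts' outputs for one token."""
    out = [0.0] * len(x)
    for idx, w in zip(chosen, weights):
        y = experts[idx](x)
        out = [o + w * yi for o, yi in zip(out, y)]
    return out


# Toy experts: expert i scales its input by (i + 1). Purely illustrative.
experts = [lambda x, s=s: [s * v for v in x] for s in (1.0, 2.0, 3.0, 4.0)]

logits = [2.0, 1.0, 0.0, -1.0]
ids_a, w_a = route_softmax_then_topk(logits, k=2)
ids_b, w_b = route_topk_then_softmax(logits, k=2)
mixed = moe_token([1.0, 2.0], experts, ids_b, w_b)
```

Both orderings select the same experts here, but only the "TopK → Softmax" weights sum to 1; with "Softmax → TopK" the probability mass of the unselected experts is dropped.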
- Parameters:
routing_logits (torch.Tensor) – [seq_len, num_experts] tensor of routing logits. float32 or bfloat16.
routing_bias (Optional[torch.Tensor]) – Optional [num_experts] tensor of routing bias. Must be bfloat16 if provided.
hidden_states (torch.Tensor) – [seq_len, hidden_size] tensor of input hidden states. Must be bfloat16.
gemm1_weights (torch.Tensor) – [num_experts, M // 128, hidden_size // 128, 128] first-layer weights, bfloat16. M equals 2 * intermediate_size for gated activations and intermediate_size for non-gated activations.
gemm2_weights (torch.Tensor) – [num_experts, hidden_size // 128, intermediate_size, 128] second-layer weights, bfloat16.
num_experts (int) – Total number of experts.
top_k (int) – Number of experts to route to per token.
n_group (Optional[int]) – Number of expert groups.
topk_group (Optional[int]) – Number of groups to consider for top-k routing.
intermediate_size (int) – Size of the intermediate layer.
local_expert_offset (int) – Offset of local experts in the global expert space.
local_num_experts (int) – Number of experts handled by this device.
routed_scaling_factor (Optional[float]) – Scaling factor for routing (may be None for some methods).
routing_method_type (int) – Routing method (default 0). Selects the routing-kernel pipeline; matches flashinfer.tllm_enums.RoutingMethodType.
    - 0 (Default): Softmax → TopK.
    - 1 (Renormalize): TopK → Softmax.
    - 2 (DeepSeekV3): Sigmoid → RoutingBiasAdd → Top-2 in group → Top-topk_group groups → Top-top_k experts from the selected groups.
    - 3 (Llama4): Top-1 → Sigmoid.
    - 4 (RenormalizeNaive): Softmax → TopK → Renormalize (Qwen3 style).
    - 5 (TopK): TopK only (no softmax/sigmoid).
    - 6 (SigmoidRenorm): Sigmoid → TopK → Renormalize (divide by the sum of the top-K weights).
    - 7 (MiniMax2): Sigmoid + Bias → TopK → ScaledSumNormalize (routeScale = 1.0, epsilon = 1e-20).
    - 8 (Sigmoid): Sigmoid → TopK (no renormalization).
    - 9 (Unspecified): reserved.
use_shuffled_weight (bool) – Whether to use the shuffled weight layout (default True).
weight_layout (int) – Weight layout for gemm1_weights / gemm2_weights; matches flashinfer.tllm_enums.WeightLayout. This BF16 MoE entry point requires BlockMajorK; passing any other value raises a runtime error. Default WeightLayout.BlockMajorK.
    - 0 (MajorK): K-major, logical shape [Mn, K]. Not supported by this function.
    - 1 (MajorMn): M-major (A) / N-major (B), logical shape [K, Mn]. Not supported by this function.
    - 2 (BlockMajorK): Blocked along K, logical shape [K / blockK, Mn, blockK] (blockK is fixed at 128 B).
do_finalize (bool) – Whether to finalize the output (default True).
enable_pdl (bool) – Whether to enable Programmatic Dependent Launch. Auto-enabled for SM90+ when True.
tune_max_num_tokens (int) – Maximum number of tokens for autotuning (default 8192).
activation_type (int) – Activation type (default 3, Swiglu). Values: 3 (Swiglu); 6 (Relu2, non-gated).
norm_topk_prob (bool) – Whether to normalize the top-k probabilities (default True).
routing_replay_out (Optional[torch.Tensor]) – Optional int16 tensor of shape (num_tokens_or_larger, top_k) used to capture the selected expert IDs during routing. Column order matches topk_indices. When None (default), the kernel skips the write entirely. The buffer may be larger than num_tokens for CUDA-graph pre-allocation; only rows [0, num_tokens) are written.
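Several of the parameters above are plain integers derived from the model and parallelism configuration. The sketch below shows one way to compute them before the call. It is an illustration under stated assumptions, not part of the FlashInfer API: the helper names, and the ep_size / ep_rank arguments, are hypothetical, and the contiguous, even split of experts across ranks is an assumed sharding policy that the documentation above does not mandate.

```python
# activation_type values taken from the parameter list above.
GATED_ACTIVATIONS = {3}       # 3 = Swiglu (gated)
NON_GATED_ACTIVATIONS = {6}   # 6 = Relu2 (non-gated)


def gemm1_rows(intermediate_size, activation_type=3):
    """M from the gemm1_weights entry: 2 * intermediate_size if gated, else intermediate_size."""
    if activation_type in GATED_ACTIVATIONS:
        return 2 * intermediate_size
    if activation_type in NON_GATED_ACTIVATIONS:
        return intermediate_size
    raise ValueError(f"activation_type {activation_type} is not listed above")


def local_expert_range(num_experts, ep_size, ep_rank):
    """(local_expert_offset, local_num_experts) for one rank.

    Assumption: experts are split evenly into contiguous blocks, one per rank.
    """
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by ep_size in this sketch")
    if not 0 <= ep_rank < ep_size:
        raise ValueError("ep_rank out of range")
    local_num_experts = num_experts // ep_size
    return ep_rank * local_num_experts, local_num_experts


def replay_buffer_shape(max_num_tokens, top_k):
    """Shape for a pre-allocated routing_replay_out buffer (int16 per the entry above).

    The buffer may exceed the actual token count; only the first num_tokens rows are written.
    """
    return (max_num_tokens, top_k)


m = gemm1_rows(intermediate_size=1024, activation_type=3)
offset, count = local_expert_range(num_experts=256, ep_size=8, ep_rank=3)
shape = replay_buffer_shape(max_num_tokens=8192, top_k=8)
```

With these inputs, rank 3 of 8 would pass local_expert_offset=96 and local_num_experts=32 under the assumed even split.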
- Returns:
If do_finalize is True, returns the final MoE output (deprecated scalar return; will become [output] in v0.8.0). Otherwise returns [gemm2_output, expert_weights, expanded_idx_to_permuted_idx].
- Return type:
torch.Tensor or List[torch.Tensor]
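Because the do_finalize=True return is documented as changing from a bare tensor to a one-element list, calling code may want to accept both forms. A minimal sketch follows; the helper name final_output is hypothetical, and the strings in the usage lines stand in for tensors so the snippet runs without a GPU.

```python
def final_output(result):
    """Return the finalized MoE output from either return convention.

    Accepts the current bare-tensor return and the announced [output] list form.
    """
    if isinstance(result, (list, tuple)):
        if len(result) != 1:
            raise ValueError(
                "expected a single-element list; a longer list looks like the "
                "do_finalize=False return"
            )
        return result[0]
    return result


# Stand-in objects instead of tensors, for illustration only.
bare = final_output("output-tensor")
wrapped = final_output(["output-tensor"])
```

The length check keeps the helper from silently returning gemm2_output when it is handed the three-element do_finalize=False result.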