flashinfer.fused_moe¶
This module provides fused Mixture-of-Experts (MoE) operations optimized for different backends and data types.
Types and Enums¶
Utility Functions¶
Reshape a 2-D tensor into a 3-D block layout.
Reorder rows of a weight tensor for the TensorRT-LLM gated-activation GEMM layout.
Interleave 4-bit packed MoE weights for the SM90 mixed-input GEMM.
Interleave MXFP4 block scales for the SM90 mixed-input MoE GEMM.
Fused expert routing with top-k selection for DeepSeek-V3.
Multi-LoRA MoE (BGMV)¶
Batched Gather-Matrix-Vector kernels for serving multiple LoRA adapters on top of a Mixture-of-Experts layer (shrink + expand).
High-level multi-LoRA MoE BGMV: shrink + expand in one call.
MoE LoRA shrink operation: project input through LoRA-A matrices.
MoE LoRA expand operation: project through LoRA-B matrices with routing weights.
CUTLASS Fused MoE¶
Compute a Mixture of Experts (MoE) layer using the CUTLASS backend.
TensorRT-LLM Fused MoE¶
BF16 MoE operation with autotuning support.
Pre-routed BF16 MoE operation with autotuning support.
FP4 block-scaled MoE operation.
FP4 block-scaled MoE operation with pre-computed routing.
FP8 block-scaled MoE operation.
Pre-routed FP8 block-scaled MoE operation.
FP8 per-tensor-scale MoE operation.
MXINT4 block-scaled MoE operation.
MXINT4 block-scaled MoE operation with pre-computed routing.
CuteDSL Fused MoE¶
The CuteDSL backends are available only when the nvidia-cutlass-dsl package is installed.
Run a fused MoE forward pass using the CuTe-DSL NVFP4 kernels.
Run fused MoE on SM120/SM121 using b12x CuTe-DSL kernels.
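Since these backends depend on an optional package, a caller may want to check for it before choosing a CuteDSL code path. The following is a minimal sketch that assumes only the distribution name nvidia-cutlass-dsl given above. The helper name cutedsl_package_installed is hypothetical and not part of flashinfer, and flashinfer's own availability check may work differently.

```python
from importlib import metadata


def cutedsl_package_installed(dist_name: str = "nvidia-cutlass-dsl") -> bool:
    """Return True if the named distribution is installed in this environment.

    Hypothetical helper, not part of flashinfer. It only inspects installed
    package metadata; it does not import the package or check GPU support.
    """
    try:
        metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return False
    return True


if __name__ == "__main__":
    print("CuteDSL package installed:", cutedsl_package_installed())
```

A metadata lookup like this avoids importing the package just to test for its presence, which keeps the check cheap on machines without a GPU.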
- class flashinfer.fused_moe.CuteDslMoEWrapper(num_experts: int, top_k: int, hidden_size: int, intermediate_size: int, use_cuda_graph: bool = False, max_num_tokens: int | None = None, num_local_experts: int | None = None, local_expert_offset: int = 0, tile_size: int = 128, sf_vec_size: int = 16, output_dtype: dtype = torch.bfloat16, device: str = 'cuda', enable_pdl: bool = True, activation: str = 'silu')¶
Bases: object
Wrapper class for CuteDSL MoE with CUDA graph and auto-tuning support.
With use_cuda_graph=True, the wrapper creates persistent CUDA stream and event resources outside graph capture, enabling async-memset / GEMM1 overlap during capture and replay. Auto-tuning is supported via the tactic parameter or autotune() context.
Supported architectures: SM100, SM103.
- num_experts¶
Total number of experts.
- top_k¶
Number of experts per token.
- hidden_size¶
Hidden dimension size.
- intermediate_size¶
Intermediate dimension size.
- use_cuda_graph¶
Whether the wrapper holds persistent stream/event resources for CUDA graph capture.
- max_num_tokens¶
Deprecated; accepted for backwards compatibility but ignored.
- Example (CUDA Graph):
>>> moe = CuteDslMoEWrapper(
...     num_experts=256, top_k=8,
...     hidden_size=7168, intermediate_size=2048,
...     use_cuda_graph=True,
... )
>>> # Warmup
>>> for _ in range(3):
...     output = moe.run(x, x_sf, topk_ids, topk_weights, w1, w1_sf, ...)
>>> # Capture
>>> g = torch.cuda.CUDAGraph()
>>> with torch.cuda.graph(g):
...     output = moe.run(x, x_sf, topk_ids, topk_weights, w1, w1_sf, ...)
>>> # Replay
>>> g.replay()
- Example (Auto-tuning):
>>> moe = CuteDslMoEWrapper(num_experts=256, top_k=8, ...)
>>> # Run with auto-tuning
>>> with autotune(True):
...     output = moe.run(x, x_sf, topk_ids, topk_weights, w1, w1_sf, ...)
- __init__(num_experts: int, top_k: int, hidden_size: int, intermediate_size: int, use_cuda_graph: bool = False, max_num_tokens: int | None = None, num_local_experts: int | None = None, local_expert_offset: int = 0, tile_size: int = 128, sf_vec_size: int = 16, output_dtype: dtype = torch.bfloat16, device: str = 'cuda', enable_pdl: bool = True, activation: str = 'silu')¶
Configure the CuTe-DSL NVFP4 fused-MoE wrapper.
- Parameters:
num_experts (int) – Total number of experts.
top_k (int) – Number of experts routed to per token.
hidden_size (int) – Hidden dimension size.
intermediate_size (int) – Intermediate dimension size (after SwiGLU reduction).
use_cuda_graph (bool) – Create persistent CUDA stream/events for async-memset overlap. Required for CUDA graph capture, since streams and events must be created outside graph capture. Defaults to False.
max_num_tokens (Optional[int]) – Deprecated; accepted for backwards compatibility but ignored.
num_local_experts (Optional[int]) – Local experts for expert parallelism. Defaults to num_experts.
local_expert_offset (int) – Offset of local experts in the global expert space. Defaults to 0.
tile_size (int) – Tile size for moe_sort. Defaults to 128.
sf_vec_size (int) – Scale-factor vector size. Defaults to 16.
output_dtype (torch.dtype) – Output dtype. Defaults to torch.bfloat16.
device (str) – Device on which to allocate buffers. Defaults to "cuda".
enable_pdl (bool) – Enable Programmatic Dependent Launch. Defaults to True.
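The num_local_experts and local_expert_offset parameters describe which slice of the global expert space a rank owns. As a rough sketch, the values for one rank could be derived as below, assuming experts are split evenly and contiguously across ranks. That partitioning scheme, the helper name local_expert_range, and the ep_size and ep_rank arguments are illustrative assumptions and are not part of the wrapper's API.

```python
from typing import Tuple


def local_expert_range(num_experts: int, ep_size: int, ep_rank: int) -> Tuple[int, int]:
    """Return (num_local_experts, local_expert_offset) for one rank.

    Hypothetical helper: assumes an even, contiguous split of experts across
    expert-parallel ranks, which is an assumption of this sketch only.
    """
    if ep_size <= 0:
        raise ValueError("ep_size must be positive")
    if num_experts % ep_size != 0:
        raise ValueError("this sketch assumes num_experts is divisible by ep_size")
    if not 0 <= ep_rank < ep_size:
        raise ValueError("ep_rank must be in [0, ep_size)")
    num_local_experts = num_experts // ep_size
    local_expert_offset = ep_rank * num_local_experts
    return num_local_experts, local_expert_offset


# Example: 256 experts split over 8 ranks, queried for rank 3.
num_local, offset = local_expert_range(num_experts=256, ep_size=8, ep_rank=3)
```

The two returned values would then be passed as num_local_experts and local_expert_offset when constructing the wrapper on that rank.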
- get_valid_tactics() list¶
Return list of valid tactics for this MoE configuration.
- run(x: Tensor, x_sf: Tensor, token_selected_experts: Tensor, token_final_scales: Tensor, w1_weight: Tensor, w1_weight_sf: Tensor, w1_alpha: Tensor, fc2_input_scale: Tensor, w2_weight: Tensor, w2_weight_sf: Tensor, w2_alpha: Tensor, tactic: Tuple | None = None) Tensor¶
Run the CuTe-DSL NVFP4 fused-MoE forward pass.
CUDA-graph safe when the wrapper was constructed with use_cuda_graph=True. Supports auto-tuning via the tactic argument or the surrounding autotune() context manager.
- Parameters:
x (torch.Tensor) – NVFP4-quantized input of shape [num_tokens, hidden_size // 2].
x_sf (torch.Tensor) – Scale factors for x.
token_selected_experts (torch.Tensor) – Expert assignments of shape [num_tokens, top_k].
token_final_scales (torch.Tensor) – Routing weights of shape [num_tokens, top_k].
w1_weight (torch.Tensor) – GEMM1 weights (gate + up fused).
w1_weight_sf (torch.Tensor) – Scale factors for w1_weight.
w1_alpha (torch.Tensor) – Per-expert global scale for GEMM1.
fc2_input_scale (torch.Tensor) – Global scale for GEMM2 input quantization.
w2_weight (torch.Tensor) – GEMM2 weights (down projection).
w2_weight_sf (torch.Tensor) – Scale factors for w2_weight.
w2_alpha (torch.Tensor) – Per-expert global scale for GEMM2.
tactic (Optional[Tuple]) – Tactic tuple, or None for auto-selection via the runtime tuner.
- Returns:
Output tensor of shape [num_tokens, hidden_size].
- Return type:
torch.Tensor
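The shapes listed for run() follow from packing two 4-bit values per byte. The sketch below computes them for a given problem size so they can be checked before calling the kernel. The helper nvfp4_input_shapes is hypothetical, and the x_sf_logical entry assumes one scale factor per sf_vec_size input elements with no padding or swizzling; the physical scale-factor layout expected by the kernel is not described here and may differ.

```python
from typing import Dict, Tuple


def nvfp4_input_shapes(
    num_tokens: int, hidden_size: int, top_k: int, sf_vec_size: int = 16
) -> Dict[str, Tuple[int, int]]:
    """Return the documented shapes for several run() tensors.

    Hypothetical helper for shape checking only; it allocates nothing.
    """
    if hidden_size % 2 != 0:
        raise ValueError("hidden_size must be even to pack two 4-bit values per byte")
    if hidden_size % sf_vec_size != 0:
        raise ValueError("this sketch assumes hidden_size is divisible by sf_vec_size")
    return {
        # Two FP4 values share one byte, so the last dimension is halved.
        "x": (num_tokens, hidden_size // 2),
        # ASSUMPTION: one scale per sf_vec_size elements, unpadded and unswizzled.
        "x_sf_logical": (num_tokens, hidden_size // sf_vec_size),
        "token_selected_experts": (num_tokens, top_k),
        "token_final_scales": (num_tokens, top_k),
        "output": (num_tokens, hidden_size),
    }


shapes = nvfp4_input_shapes(num_tokens=4, hidden_size=7168, top_k=8)
```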
- class flashinfer.fused_moe.B12xMoEWrapper(num_experts: int, top_k: int, hidden_size: int, intermediate_size: int, *, use_cuda_graph: bool = False, max_num_tokens: int = 4096, num_local_experts: int | None = None, output_dtype: dtype = torch.bfloat16, device: str = 'cuda', activation: str = 'silu', activation_precision: str = 'fp4', quant_mode: str | None = None, source_format: str = 'modelopt')¶
Bases: object
B12x fused MoE wrapper for SM120/SM121 with CUDA graph support.
Pre-allocates workspace buffers for CUDA graph compatibility. Automatically selects micro/static/dynamic backend per call.
- Parameters:
num_experts – Total number of experts.
top_k – Number of experts per token.
hidden_size – Hidden dimension size.
intermediate_size – Intermediate size.
use_cuda_graph – Pre-allocate buffers for CUDA graph compatibility.
max_num_tokens – Maximum number of tokens; only used when use_cuda_graph=True.
num_local_experts – Local experts for EP. Default: num_experts.
output_dtype – Output data type. Only torch.bfloat16 is currently supported. Default: torch.bfloat16.
device – Device for buffer allocation. Default: "cuda".
activation – Activation function, either "silu" or "relu2". Default: "silu".
activation_precision – Backward-compatible alias for quant_mode. "fp4" selects quant_mode="nvfp4"; "bf16" selects quant_mode="w4a16".
quant_mode – Quantization mode: "nvfp4"/"w4a4" or "w4a16". When set, this selects the backend and internal workspace family.
source_format – Source weight format for quant_mode="w4a16". Supports "modelopt" and "compressed_tensors". Default: "modelopt".
Example
>>> moe = B12xMoEWrapper(num_experts=256, top_k=8, ...)
>>> output = moe.run(x=hidden_states_bf16, ...)
- __init__(num_experts: int, top_k: int, hidden_size: int, intermediate_size: int, *, use_cuda_graph: bool = False, max_num_tokens: int = 4096, num_local_experts: int | None = None, output_dtype: dtype = torch.bfloat16, device: str = 'cuda', activation: str = 'silu', activation_precision: str = 'fp4', quant_mode: str | None = None, source_format: str = 'modelopt')¶
Configure the b12x fused-MoE wrapper.
- Parameters:
num_experts (int) – Total number of experts.
top_k (int) – Number of experts routed to per token.
hidden_size (int) – Hidden dimension size.
intermediate_size (int) – Intermediate dimension size.
use_cuda_graph (bool) – If True, pre-allocate workspace buffers sized for max_num_tokens so the wrapper can be captured into a CUDA graph. Defaults to False.
max_num_tokens (int) – Maximum batch size, only used when use_cuda_graph=True. Defaults to 4096.
num_local_experts (Optional[int]) – Number of local experts for expert parallelism. Defaults to num_experts.
output_dtype (torch.dtype) – Output dtype. Only torch.bfloat16 is currently supported.
device (str) – Device on which to allocate workspace buffers. Defaults to "cuda".
activation (str) – Activation function: "silu" (gated SwiGLU) or "relu2" (non-gated). Defaults to "silu".
activation_precision (str) – Backward-compatible alias for quant_mode. "fp4" selects quant_mode="nvfp4"; "bf16" selects quant_mode="w4a16".
quant_mode (Optional[str]) – Quantization mode, "nvfp4"/"w4a4" or "w4a16".
source_format (str) – Source weight format for quant_mode="w4a16": "modelopt" (default) or "compressed_tensors".
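The relationship between activation_precision and quant_mode can be summarized as a small lookup. The sketch below illustrates the documented alias mapping only. The function resolve_quant_mode is hypothetical, and letting an explicitly set quant_mode take precedence over the alias is an assumption drawn from the "when set" wording above, not a confirmed description of the wrapper's internals.

```python
from typing import Optional

# Alias mapping as documented: "fp4" -> "nvfp4", "bf16" -> "w4a16".
_ALIAS_TO_QUANT_MODE = {"fp4": "nvfp4", "bf16": "w4a16"}
_VALID_QUANT_MODES = {"nvfp4", "w4a4", "w4a16"}


def resolve_quant_mode(
    activation_precision: str = "fp4", quant_mode: Optional[str] = None
) -> str:
    """Pick the effective quantization mode from the two constructor arguments.

    Hypothetical helper. ASSUMPTION: an explicit quant_mode wins over the
    activation_precision alias.
    """
    if quant_mode is not None:
        if quant_mode not in _VALID_QUANT_MODES:
            raise ValueError(f"unsupported quant_mode: {quant_mode!r}")
        return quant_mode
    if activation_precision not in _ALIAS_TO_QUANT_MODE:
        raise ValueError(f"unsupported activation_precision: {activation_precision!r}")
    return _ALIAS_TO_QUANT_MODE[activation_precision]
```

With the constructor defaults (activation_precision="fp4", quant_mode=None), this sketch resolves to "nvfp4".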
- run(x: Tensor, w1_weight: Tensor, w1_weight_sf: Tensor, w2_weight: Tensor, w2_weight_sf: Tensor, token_selected_experts: Tensor, token_final_scales: Tensor, *, w1_alpha: Tensor, w2_alpha: Tensor, fc2_input_scale: Tensor | None = None) Tensor¶
Run the b12x fused-MoE forward pass.
- Parameters:
x (torch.Tensor) – Input activations of shape [num_tokens, hidden_size], bfloat16.
w1_weight (torch.Tensor) – FC1 weights, FP4-packed.
w1_weight_sf (torch.Tensor) – Scale factors for w1_weight.
w2_weight (torch.Tensor) – FC2 weights, FP4-packed.
w2_weight_sf (torch.Tensor) – Scale factors for w2_weight.
token_selected_experts (torch.Tensor) – Expert assignments of shape [num_tokens, top_k].
token_final_scales (torch.Tensor) – Routing weights of shape [num_tokens, top_k].
w1_alpha (torch.Tensor) – Per-expert global scale for FC1.
w2_alpha (torch.Tensor) – Per-expert global scale for FC2.
fc2_input_scale (Optional[torch.Tensor]) – Global scale for FC2 input quantization. Required for quant_mode="nvfp4"; accepted but ignored for "w4a16".
- Returns:
Output tensor of shape [num_tokens, hidden_size].
- Return type:
torch.Tensor
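To make the roles of token_selected_experts and token_final_scales concrete, the following unquantized pure-Python reference sketches what a gated ("silu") fused MoE layer computes per token. It is an illustration only and not the library's implementation: it ignores FP4 packing, scale factors, and the alpha terms, and it assumes the fused FC1 output is laid out as the gate half followed by the up half, which may not match the kernels' actual weight layout. All names in the block are hypothetical.

```python
import math
from typing import List, Sequence

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, v: Sequence[float]) -> Vector:
    """Multiply a row-major matrix by a vector."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in w]


def silu(z: float) -> float:
    """SiLU activation: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))


def reference_moe(
    x: List[Vector],
    w1: List[Matrix],  # per expert: [2 * intermediate_size, hidden_size]
    w2: List[Matrix],  # per expert: [hidden_size, intermediate_size]
    token_selected_experts: List[List[int]],
    token_final_scales: List[List[float]],
) -> List[Vector]:
    """Return an output of shape [num_tokens, hidden_size]."""
    hidden_size = len(x[0])
    output = []
    for token, expert_ids, scales in zip(x, token_selected_experts, token_final_scales):
        acc = [0.0] * hidden_size
        for expert_id, scale in zip(expert_ids, scales):
            fused = matvec(w1[expert_id], token)
            inter = len(fused) // 2
            # ASSUMPTION: gate rows come first, up rows second.
            gate, up = fused[:inter], fused[inter:]
            hidden = [silu(g) * u for g, u in zip(gate, up)]
            expert_out = matvec(w2[expert_id], hidden)
            # Accumulate this expert's output, weighted by its routing scale.
            for i, value in enumerate(expert_out):
                acc[i] += scale * value
        output.append(acc)
    return output
```

A reference of this kind is mainly useful for checking a fused kernel's output on small, dequantized inputs within a tolerance.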