flashinfer.fused_moe

This module provides fused Mixture-of-Experts (MoE) operations optimized for different backends and data types.

Types and Enums

RoutingMethodType(value[, names, module, ...])

WeightLayout(value[, names, module, ...])

Utility Functions

convert_to_block_layout(input_tensor, blockK)

Reshape a 2-D tensor into a 3-D block layout.
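A minimal pure-PyTorch sketch of such a reshape. It assumes the layout splits the K dimension into K // blockK blocks and moves the block index to the front ([M, K] → [K // blockK, M, blockK]). The helper name block_layout_reference is hypothetical; use convert_to_block_layout itself when preparing kernel inputs.

```python
import torch

def block_layout_reference(t: torch.Tensor, block_k: int) -> torch.Tensor:
    """Hypothetical reference (assumed layout): [M, K] -> [K // block_k, M, block_k]."""
    m, k = t.shape
    if k % block_k != 0:
        raise ValueError("K must be divisible by block_k")
    # Split K into (num_blocks, block_k), then move the block index to dim 0.
    return t.view(m, k // block_k, block_k).permute(1, 0, 2).contiguous()

x = torch.arange(18).view(2, 9)
y = block_layout_reference(x, 3)
print(y.shape)  # torch.Size([3, 2, 3])
```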

reorder_rows_for_gated_act_gemm(x)

Reorder rows of a weight tensor for the TensorRT-LLM gated-activation GEMM layout.
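Gated-activation GEMMs usually receive the two halves of a fused projection (for example, the up and gate rows) stacked along dim 0. The following sketch assumes the TensorRT-LLM layout interleaves those halves one row at a time. The real helper may use a different granularity, so treat this as an illustration of the idea rather than a drop-in replacement. reorder_rows_reference is hypothetical.

```python
import torch

def reorder_rows_reference(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: interleave the two halves of dim 0 row by row."""
    rows = x.shape[0]
    if rows % 2 != 0:
        raise ValueError("number of rows must be even")
    half = rows // 2
    first = torch.arange(half)
    second = torch.arange(half, rows)
    # Index order [0, half, 1, half + 1, ...]: each row of the first half is
    # followed by its counterpart from the second half.
    index = torch.stack([first, second], dim=1).reshape(-1)
    return x[index]

w = torch.arange(6).view(6, 1)
print(reorder_rows_reference(w).flatten().tolist())  # [0, 3, 1, 4, 2, 5]
```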

interleave_moe_weights_for_sm90_mixed_gemm(weight)

Interleave 4-bit packed MoE weights for the SM90 mixed-input GEMM.

interleave_moe_scales_for_sm90_mixed_gemm(scales)

Interleave MXFP4 block scales for the SM90 mixed-input MoE GEMM.

fused_topk_deepseek(scores, bias, n_group, ...)

Fused expert routing with top-k selection for DeepSeek-V3.
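The following unfused PyTorch sketch shows the routing this kernel fuses, following the DeepSeek-V3 scheme as it is generally described:

- sigmoid expert scores;
- a per-expert bias that affects only which experts are selected;
- group-limited selection, keeping the best topk_group of n_group groups, with each group scored by the sum of its two best experts;
- renormalized weights multiplied by a routed scaling factor.

The function name and argument list (including routed_scaling_factor) are illustrative assumptions. Check fused_topk_deepseek's docstring to see whether it expects raw logits or already-activated scores, and what exactly it returns.

```python
import torch

def deepseek_v3_routing_reference(
    logits: torch.Tensor,  # [num_tokens, num_experts]
    bias: torch.Tensor,    # [num_experts], selection-only correction bias
    n_group: int,
    topk_group: int,
    top_k: int,
    routed_scaling_factor: float = 1.0,
):
    num_tokens, num_experts = logits.shape
    assert num_experts % n_group == 0 and num_experts // n_group >= 2
    scores = torch.sigmoid(logits.float())
    # The bias only influences which experts are chosen, not their weights.
    biased = scores + bias.float()
    grouped = biased.view(num_tokens, n_group, num_experts // n_group)
    # Score each group by the sum of its two best biased expert scores.
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)  # [T, n_group]
    # Keep the best topk_group groups; mask out every expert in other groups.
    top_groups = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, top_groups, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(num_tokens, num_experts)
    masked = biased.masked_fill(expert_mask == 0, float("-inf"))
    topk_ids = masked.topk(top_k, dim=-1).indices
    # Weights come from the unbiased scores, renormalized per token.
    weights = scores.gather(1, topk_ids)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return topk_ids, weights * routed_scaling_factor
```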

Multi-LoRA MoE (BGMV)

Batched Gather-Matrix-Vector kernels for serving multiple LoRA adapters on top of a Mixture-of-Experts layer (shrink + expand).

bgmv_moe(x, lora_a_weights, lora_b_weights, ...)

High-level multi-LoRA MoE BGMV: shrink + expand in one call.

bgmv_moe_shrink(y, x, w_ptr, ...)

MoE LoRA shrink operation: project input through LoRA-A matrices.

bgmv_moe_expand(y, x, w_ptr, ...)

MoE LoRA expand operation: project through LoRA-B matrices with routing weights.
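A dense PyTorch reference can illustrate the split between the shrink and expand steps. Every interface detail here is an assumption made for illustration:

- moe_lora_reference is hypothetical;
- LoRA-A is laid out as [num_loras, num_experts, rank, hidden];
- LoRA-B is laid out as [num_loras, num_experts, out_dim, rank];
- each token has one adapter id, with -1 meaning no adapter;
- the LoRA delta is summed over the top-k experts.

The sketch shows only the arithmetic the two steps perform, not the calling convention of bgmv_moe, bgmv_moe_shrink, or bgmv_moe_expand.

```python
import torch

def moe_lora_reference(x, topk_ids, topk_weights, lora_ids, lora_a, lora_b):
    """x: [T, H]; topk_ids/topk_weights: [T, K]; lora_ids: [T] (-1 = no adapter).

    lora_a: [L, E, r, H], lora_b: [L, E, N, r] (assumed layouts). Returns [T, N].
    """
    num_tokens, k = topk_ids.shape
    y = x.new_zeros(num_tokens, lora_b.shape[2])
    active = lora_ids >= 0
    if not active.any():
        return y
    adapters = lora_ids[active].unsqueeze(1).expand(-1, k)  # [T', K]
    experts = topk_ids[active]                              # [T', K]
    a = lora_a[adapters, experts]                           # [T', K, r, H]
    b = lora_b[adapters, experts]                           # [T', K, N, r]
    # Shrink: project each token into the rank-r space of every selected expert.
    shrunk = torch.einsum("tkrh,th->tkr", a, x[active])
    # Expand: project back to N dims, weight by the router, sum over top-k.
    expanded = torch.einsum("tknr,tkr->tkn", b, shrunk)
    y[active] = (expanded * topk_weights[active].unsqueeze(-1)).sum(dim=1)
    return y
```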

CUTLASS Fused MoE

cutlass_fused_moe(input, ...[, ...])

Compute a Mixture of Experts (MoE) layer using the CUTLASS backend.
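For reference, the fused kernels in this module compute something like the following unfused, full-precision SwiGLU MoE layer with precomputed routing. This is not the cutlass_fused_moe signature. The weight layouts ([num_experts, 2 * intermediate, hidden] for the fused gate/up projection and [num_experts, hidden, intermediate] for the down projection) and the gate-before-up order are assumptions. The fused backends also handle quantization, scaling factors, and expert parallelism.

```python
import torch
import torch.nn.functional as F

def moe_reference(x, topk_ids, topk_weights, w1, w2):
    """x: [T, H]; topk_ids/topk_weights: [T, K];
    w1: [E, 2*I, H] (gate rows first, then up: assumed); w2: [E, H, I]."""
    num_tokens, hidden = x.shape
    out = torch.zeros(num_tokens, hidden, dtype=torch.float32, device=x.device)
    for e in topk_ids.unique().tolist():
        # All (token, slot) pairs routed to expert e.
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        h = x[token_idx].float() @ w1[e].float().t()  # [n, 2I]
        gate, up = h.chunk(2, dim=-1)
        act = F.silu(gate) * up                       # SwiGLU -> [n, I]
        y = act @ w2[e].float().t()                   # [n, H]
        weight = topk_weights[token_idx, slot].float().unsqueeze(-1)
        out.index_add_(0, token_idx, y * weight)      # scatter-add back per token
    return out.to(x.dtype)
```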

TensorRT-LLM Fused MoE

trtllm_bf16_moe(routing_logits, ...[, ...])

BF16 MoE operation with autotuning support.

trtllm_bf16_routed_moe(topk_ids, ...[, ...])

Pre-routed BF16 MoE operation with autotuning support.

trtllm_fp4_block_scale_moe(routing_logits, ...)

FP4 block-scaled MoE operation.

trtllm_fp4_block_scale_routed_moe(topk_ids, ...)

Pre-routed FP4 block-scaled MoE operation.

trtllm_fp8_block_scale_moe(routing_logits, ...)

FP8 block-scaled MoE operation.

trtllm_fp8_block_scale_routed_moe(topk_ids, ...)

Pre-routed FP8 block-scaled MoE operation.

trtllm_fp8_per_tensor_scale_moe(...[, ...])

FP8 per-tensor-scale MoE operation.

trtllm_mxint4_block_scale_moe(...[, ...])

MXINT4 block-scaled MoE operation.

trtllm_mxint4_block_scale_routed_moe(...[, ...])

Pre-routed MXINT4 block-scaled MoE operation.
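The *_routed_moe variants accept routing decisions computed elsewhere (topk_ids) instead of routing_logits. The sketch below produces top-k expert ids and weights using softmax, then top-k, then renormalization, which is one common scheme (see RoutingMethodType for the schemes available to the non-routed variants). softmax_topk_routing is hypothetical. The sketch does not show how ids and weights must be packed or passed to each *_routed_moe function (dtype, packing); check that function's docstring.

```python
import torch

def softmax_topk_routing(logits: torch.Tensor, top_k: int, renormalize: bool = True):
    """Return (topk_ids [T, top_k] int32, topk_weights [T, top_k] float32)."""
    probs = torch.softmax(logits.float(), dim=-1)
    weights, ids = probs.topk(top_k, dim=-1)  # sorted, highest probability first
    if renormalize:
        # Rescale the selected probabilities so each token's weights sum to 1.
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return ids.to(torch.int32), weights
```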

CuteDSL Fused MoE

The CuteDSL backends are available only when the nvidia-cutlass-dsl package is installed.
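Callers that want to fall back to another backend when CuteDSL is missing can test for the distribution before importing. The following minimal sketch uses only the standard library. cute_dsl_available is a hypothetical helper, and it checks only that the package metadata is present, not that the current GPU is supported.

```python
from importlib.metadata import PackageNotFoundError, version

def cute_dsl_available(dist_name: str = "nvidia-cutlass-dsl") -> bool:
    """Return True if the named distribution is installed (metadata check only)."""
    try:
        version(dist_name)
    except PackageNotFoundError:
        return False
    return True
```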

cute_dsl_fused_moe_nvfp4(x, x_sf, ...[, ...])

Run a fused MoE forward pass using the CuTe-DSL NVFP4 kernels.

b12x_fused_moe(x, w1_weight, w1_weight_sf, ...)

Run fused MoE on SM120/SM121 using b12x CuTe-DSL kernels.

class flashinfer.fused_moe.CuteDslMoEWrapper(num_experts: int, top_k: int, hidden_size: int, intermediate_size: int, use_cuda_graph: bool = False, max_num_tokens: int | None = None, num_local_experts: int | None = None, local_expert_offset: int = 0, tile_size: int = 128, sf_vec_size: int = 16, output_dtype: dtype = torch.bfloat16, device: str = 'cuda', enable_pdl: bool = True, activation: str = 'silu')

Bases: object

Wrapper class for CuteDSL MoE with CUDA graph and auto-tuning support.

With use_cuda_graph=True, the wrapper creates persistent CUDA streams and events outside graph capture. This lets the async memset overlap with GEMM1 during both capture and replay. Auto-tuning is supported through the tactic parameter or the autotune() context manager.

Supported architectures: SM100, SM103.
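The CuteDSL wrappers target different architectures. CuteDslMoEWrapper lists SM100 and SM103, while B12xMoEWrapper and b12x_fused_moe target SM120/SM121. The helper below is a hypothetical dispatch function that maps a compute capability (as returned by torch.cuda.get_device_capability()) to the documented wrapper. It encodes only the architectures named on this page.

```python
from typing import Optional

# Compute capabilities named in this document for each wrapper.
_CUTE_DSL_NVFP4 = {(10, 0), (10, 3)}  # SM100, SM103
_B12X = {(12, 0), (12, 1)}            # SM120, SM121

def pick_cute_dsl_moe_wrapper(major: int, minor: int) -> Optional[str]:
    """Return the name of the CuteDSL MoE wrapper documented for this GPU, if any."""
    if (major, minor) in _CUTE_DSL_NVFP4:
        return "CuteDslMoEWrapper"
    if (major, minor) in _B12X:
        return "B12xMoEWrapper"
    return None

# On a CUDA machine: pick_cute_dsl_moe_wrapper(*torch.cuda.get_device_capability())
print(pick_cute_dsl_moe_wrapper(10, 0))  # CuteDslMoEWrapper
```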

num_experts

Total number of experts.

top_k

Number of experts per token.

hidden_size

Hidden dimension size.

intermediate_size

Intermediate dimension size.

use_cuda_graph

Whether the wrapper holds persistent stream/event resources for CUDA graph capture.

max_num_tokens

Deprecated; accepted for backwards compatibility but ignored.

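Example (CUDA Graph):
>>> import torch
>>> from flashinfer.fused_moe import CuteDslMoEWrapper
>>> moe = CuteDslMoEWrapper(
...     num_experts=256, top_k=8,
...     hidden_size=7168, intermediate_size=2048,
...     use_cuda_graph=True,
... )
>>> # x, x_sf, topk_ids, topk_weights, w1, w1_sf, ... are pre-built NVFP4 inputs.
>>> # Warmup outside of capture
>>> for _ in range(3):
...     output = moe.run(x, x_sf, topk_ids, topk_weights, w1, w1_sf, ...)
>>> # Capture
>>> g = torch.cuda.CUDAGraph()
>>> with torch.cuda.graph(g):
...     output = moe.run(x, x_sf, topk_ids, topk_weights, w1, w1_sf, ...)
>>> # Replay: write new data into the captured input tensors in place
>>> # (e.g. x.copy_(new_x)); the result lands in the captured `output`.
>>> g.replay()

Example (Auto-tuning):
>>> from flashinfer.autotuner import autotune
>>> moe = CuteDslMoEWrapper(num_experts=256, top_k=8, ...)
>>> # Run with auto-tuning
>>> with autotune(True):
...     output = moe.run(x, x_sf, topk_ids, topk_weights, w1, w1_sf, ...)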
__init__(num_experts: int, top_k: int, hidden_size: int, intermediate_size: int, use_cuda_graph: bool = False, max_num_tokens: int | None = None, num_local_experts: int | None = None, local_expert_offset: int = 0, tile_size: int = 128, sf_vec_size: int = 16, output_dtype: dtype = torch.bfloat16, device: str = 'cuda', enable_pdl: bool = True, activation: str = 'silu')

Configure the CuTe-DSL NVFP4 fused-MoE wrapper.

Parameters:
  • num_experts (int) – Total number of experts.

  • top_k (int) – Number of experts routed to per token.

  • hidden_size (int) – Hidden dimension size.

  • intermediate_size (int) – Intermediate dimension size per expert, i.e. the width after SwiGLU combines the gate and up halves of GEMM1.

  • use_cuda_graph (bool) – Create persistent CUDA stream/events for async-memset overlap. Required for CUDA graph capture, since streams and events must be created outside graph capture. Defaults to False.

  • max_num_tokens (Optional[int]) – Deprecated; accepted for backwards compatibility but ignored.

  • num_local_experts (Optional[int]) – Local experts for expert parallelism. Defaults to num_experts.

  • local_expert_offset (int) – Offset of local experts in the global expert space. Defaults to 0.

  • tile_size (int) – Tile size for moe_sort. Defaults to 128.

  • sf_vec_size (int) – Scale-factor vector size. Defaults to 16.

  • output_dtype (torch.dtype) – Output dtype. Defaults to torch.bfloat16.

  • device (str) – Device on which to allocate buffers. Defaults to "cuda".

  • enable_pdl (bool) – Enable Programmatic Dependent Launch. Defaults to True.

  • activation (str) – Activation function. Defaults to "silu".

get_valid_tactics() → list

Return list of valid tactics for this MoE configuration.
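The autotune() context manager is the documented way to pick tactics. For comparison or debugging, you can also sweep get_valid_tactics() manually using run's tactic argument, as sketched below. pick_fastest_tactic is a hypothetical helper. It accepts any callable so it can be exercised without a GPU. Timing with a wall clock plus an explicit synchronize is a simplification; CUDA events would be more precise.

```python
import time
from typing import Callable, Sequence

def pick_fastest_tactic(
    run: Callable[[object], object],
    tactics: Sequence[object],
    warmup: int = 2,
    iters: int = 10,
    sync: Callable[[], None] = lambda: None,
    timer: Callable[[], float] = time.perf_counter,
):
    """Time run(tactic) for each candidate; return (best_tactic, seconds_per_call)."""
    if not tactics:
        raise ValueError("no tactics to try")
    best, best_t = None, float("inf")
    for tactic in tactics:
        for _ in range(warmup):
            run(tactic)
        sync()  # finish warmup work before starting the clock
        start = timer()
        for _ in range(iters):
            run(tactic)
        sync()  # wait for all timed launches to complete
        elapsed = (timer() - start) / iters
        if elapsed < best_t:
            best, best_t = tactic, elapsed
    return best, best_t
```

On a GPU, one might call it as pick_fastest_tactic(lambda t: moe.run(*inputs, tactic=t), moe.get_valid_tactics(), sync=torch.cuda.synchronize), where inputs holds the eleven tensor arguments of run.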

run(x: Tensor, x_sf: Tensor, token_selected_experts: Tensor, token_final_scales: Tensor, w1_weight: Tensor, w1_weight_sf: Tensor, w1_alpha: Tensor, fc2_input_scale: Tensor, w2_weight: Tensor, w2_weight_sf: Tensor, w2_alpha: Tensor, tactic: Tuple | None = None) → Tensor

Run the CuTe-DSL NVFP4 fused-MoE forward pass.

CUDA-graph safe when the wrapper was constructed with use_cuda_graph=True. Supports auto-tuning via the tactic argument or the surrounding autotune() context manager.

Parameters:
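  • x (torch.Tensor) – NVFP4-quantized input of shape [num_tokens, hidden_size // 2], with two FP4 values packed per byte.

  • x_sf (torch.Tensor) – Block scale factors for x, one per sf_vec_size elements along the hidden dimension.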

  • token_selected_experts (torch.Tensor) – Expert assignments of shape [num_tokens, top_k].

  • token_final_scales (torch.Tensor) – Routing weights of shape [num_tokens, top_k].

  • w1_weight (torch.Tensor) – GEMM1 weights (gate + up fused).

  • w1_weight_sf (torch.Tensor) – Scale factors for w1_weight.

  • w1_alpha (torch.Tensor) – Per-expert global scale for GEMM1.

  • fc2_input_scale (torch.Tensor) – Global scale for GEMM2 input quantization.

  • w2_weight (torch.Tensor) – GEMM2 weights (down projection).

  • w2_weight_sf (torch.Tensor) – Scale factors for w2_weight.

  • w2_alpha (torch.Tensor) – Per-expert global scale for GEMM2.

  • tactic (Optional[Tuple]) – Tactic tuple, or None for auto-selection via the runtime tuner.

Returns:

Output tensor of shape [num_tokens, hidden_size].

Return type:

torch.Tensor
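To make the shapes above concrete, here is a simplified NVFP4-style block quantizer in plain PyTorch. x has shape [num_tokens, hidden_size // 2] because two FP4 values share a byte, and x_sf holds one scale per sf_vec_size elements. This is a sketch, not flashinfer's quantizer, and it leaves out several details:

- real NVFP4 encodes block scales as FP8 E4M3 and applies an additional global scale;
- the kernels may expect a specific (for example, swizzled) scale layout;
- the nibble order used here is an assumption.

Use flashinfer's own quantization utilities to produce actual x / x_sf inputs. nvfp4_quantize_reference and nvfp4_dequantize_reference are hypothetical.

```python
import torch

# Magnitudes representable in FP4 E2M1, indexed by their 3-bit code.
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def nvfp4_quantize_reference(x: torch.Tensor, sf_vec_size: int = 16):
    """Simplified NVFP4-style quantization of a [T, H] tensor.

    Returns packed uint8 codes [T, H // 2] and float32 block scales
    [T, H // sf_vec_size]. Scales stay in float32 (no FP8 encoding, no global
    scale) and use a plain row-major layout.
    """
    t, h = x.shape
    assert h % sf_vec_size == 0 and sf_vec_size % 2 == 0
    blocks = x.float().reshape(t, h // sf_vec_size, sf_vec_size)
    # One scale per block, chosen so the block's largest magnitude maps to 6.0.
    amax = blocks.abs().amax(dim=-1, keepdim=True)
    scales = torch.where(amax > 0, amax / 6.0, torch.ones_like(amax))
    normalized = blocks / scales
    # Round each magnitude to the nearest E2M1 value; bit 3 holds the sign.
    dist = (normalized.abs().unsqueeze(-1) - E2M1_VALUES).abs()
    codes = dist.argmin(dim=-1).to(torch.uint8)
    codes = codes | ((normalized < 0).to(torch.uint8) << 3)
    codes = codes.reshape(t, h)
    # Two codes per byte; the even element goes in the low nibble (assumed order).
    packed = codes[:, 0::2] | (codes[:, 1::2] << 4)
    return packed, scales.reshape(t, h // sf_vec_size)

def nvfp4_dequantize_reference(packed, scales, sf_vec_size: int = 16):
    """Inverse of nvfp4_quantize_reference (up to rounding)."""
    t = packed.shape[0]
    # Unpack low then high nibble to restore the original element order.
    codes = torch.stack([packed & 0xF, packed >> 4], dim=-1).reshape(t, -1)
    magnitude = E2M1_VALUES[(codes & 0x7).long()]
    sign = 1.0 - 2.0 * ((codes & 0x8) != 0).float()
    values = (sign * magnitude).reshape(t, -1, sf_vec_size) * scales.unsqueeze(-1)
    return values.reshape(t, -1)
```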

class flashinfer.fused_moe.B12xMoEWrapper(num_experts: int, top_k: int, hidden_size: int, intermediate_size: int, *, use_cuda_graph: bool = False, max_num_tokens: int = 4096, num_local_experts: int | None = None, output_dtype: dtype = torch.bfloat16, device: str = 'cuda', activation: str = 'silu', activation_precision: str = 'fp4', quant_mode: str | None = None, source_format: str = 'modelopt')

Bases: object

B12x fused MoE wrapper for SM120/SM121 with CUDA graph support.

When use_cuda_graph=True, pre-allocates workspace buffers so that the wrapper can be captured in a CUDA graph. Automatically selects a micro, static, or dynamic backend on each call.

Parameters:
  • num_experts – Total number of experts.

  • top_k – Number of experts per token.

  • hidden_size – Hidden dimension size.

  • intermediate_size – Intermediate dimension size.

  • use_cuda_graph – Pre-allocate buffers for CUDA graph compatibility.

  • max_num_tokens – Maximum tokens (only for use_cuda_graph=True).

  • num_local_experts – Local experts for expert parallelism (EP). Default: num_experts.

  • output_dtype – Output data type. Only torch.bfloat16 is currently supported. Default: torch.bfloat16.

  • device – Device for buffer allocation. Default: "cuda".

  • activation – Activation function: "silu" or "relu2". Default: "silu".

  • activation_precision – Backward-compatible alias for quant_mode. "fp4" selects quant_mode="nvfp4"; "bf16" selects quant_mode="w4a16".

  • quant_mode – Quantization mode: "nvfp4"/"w4a4" or "w4a16". When set, this selects the backend and internal workspace family.

  • source_format – Source weight format for quant_mode="w4a16". Supports "modelopt" and "compressed_tensors". Default: "modelopt".

Example

>>> from flashinfer.fused_moe import B12xMoEWrapper
>>> moe = B12xMoEWrapper(num_experts=256, top_k=8, ...)
>>> output = moe.run(x=hidden_states_bf16, ...)
__init__(num_experts: int, top_k: int, hidden_size: int, intermediate_size: int, *, use_cuda_graph: bool = False, max_num_tokens: int = 4096, num_local_experts: int | None = None, output_dtype: dtype = torch.bfloat16, device: str = 'cuda', activation: str = 'silu', activation_precision: str = 'fp4', quant_mode: str | None = None, source_format: str = 'modelopt')

Configure the b12x fused-MoE wrapper.

Parameters:
  • num_experts (int) – Total number of experts.

  • top_k (int) – Number of experts routed to per token.

  • hidden_size (int) – Hidden dimension size.

  • intermediate_size (int) – Intermediate dimension size.

  • use_cuda_graph (bool) – If True, pre-allocate workspace buffers sized for max_num_tokens so the wrapper can be captured into a CUDA graph. Defaults to False.

  • max_num_tokens (int) – Maximum batch size, only used when use_cuda_graph=True. Defaults to 4096.

  • num_local_experts (Optional[int]) – Number of local experts for expert parallelism. Defaults to num_experts.

  • output_dtype (torch.dtype) – Output dtype. Only torch.bfloat16 is currently supported.

  • device (str) – Device on which to allocate workspace buffers. Defaults to "cuda".

  • activation (str) – Activation function — "silu" (gated SwiGLU) or "relu2" (non-gated). Defaults to "silu".

  • activation_precision (str) – Backward-compatible alias for quant_mode. "fp4" selects quant_mode="nvfp4"; "bf16" selects quant_mode="w4a16".

  • quant_mode (Optional[str]) – Quantization mode, "nvfp4" / "w4a4" or "w4a16".

  • source_format (str) – Source weight format for quant_mode="w4a16": "modelopt" (default) or "compressed_tensors".
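A small resolver can summarize how activation_precision and quant_mode interact. This is a hypothetical sketch, not the wrapper's code, and it makes three assumptions:

- quant_mode takes precedence whenever it is set, as the description above suggests;
- "w4a4" is an alias of "nvfp4";
- fc2_input_scale is required for "nvfp4", per the rule documented for run().

```python
from typing import Optional

# Hypothetical tables mirroring the documented options.
_PRECISION_TO_MODE = {"fp4": "nvfp4", "bf16": "w4a16"}
_MODE_ALIASES = {"nvfp4": "nvfp4", "w4a4": "nvfp4", "w4a16": "w4a16"}

def resolve_quant_mode(activation_precision: str = "fp4",
                       quant_mode: Optional[str] = None) -> str:
    """Return the canonical quant mode; quant_mode wins when set (assumed)."""
    if quant_mode is not None:
        if quant_mode not in _MODE_ALIASES:
            raise ValueError(f"unsupported quant_mode {quant_mode!r}")
        return _MODE_ALIASES[quant_mode]
    if activation_precision not in _PRECISION_TO_MODE:
        raise ValueError(f"unsupported activation_precision {activation_precision!r}")
    return _PRECISION_TO_MODE[activation_precision]

def check_run_inputs(mode: str, fc2_input_scale) -> None:
    """fc2_input_scale is required for nvfp4 and ignored for w4a16."""
    if mode == "nvfp4" and fc2_input_scale is None:
        raise ValueError("fc2_input_scale is required when quant_mode='nvfp4'")
```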

run(x: Tensor, w1_weight: Tensor, w1_weight_sf: Tensor, w2_weight: Tensor, w2_weight_sf: Tensor, token_selected_experts: Tensor, token_final_scales: Tensor, *, w1_alpha: Tensor, w2_alpha: Tensor, fc2_input_scale: Tensor | None = None) → Tensor

Run the b12x fused-MoE forward pass.

Parameters:
  • x (torch.Tensor) – Input activations of shape [num_tokens, hidden_size], bfloat16.

  • w1_weight (torch.Tensor) – FC1 weights, FP4-packed.

  • w1_weight_sf (torch.Tensor) – Scale factors for w1_weight.

  • w2_weight (torch.Tensor) – FC2 weights, FP4-packed.

  • w2_weight_sf (torch.Tensor) – Scale factors for w2_weight.

  • token_selected_experts (torch.Tensor) – Expert assignments of shape [num_tokens, top_k].

  • token_final_scales (torch.Tensor) – Routing weights of shape [num_tokens, top_k].

  • w1_alpha (torch.Tensor) – Per-expert global scale for FC1.

  • w2_alpha (torch.Tensor) – Per-expert global scale for FC2.

  • fc2_input_scale (Optional[torch.Tensor]) – Global scale for FC2 input quantization. Required for quant_mode="nvfp4"; accepted but ignored for "w4a16".

Returns:

Output tensor of shape [num_tokens, hidden_size].

Return type:

torch.Tensor