flashinfer.fused_moe.cutlass_fused_moe
- flashinfer.fused_moe.cutlass_fused_moe(input: torch.Tensor, token_selected_experts: torch.Tensor, token_final_scales: torch.Tensor, fc1_expert_weights: torch.Tensor, fc2_expert_weights: torch.Tensor, output_dtype: torch.dtype, quant_scales: List[torch.Tensor], fc1_expert_biases: torch.Tensor | None = None, fc2_expert_biases: torch.Tensor | None = None, input_sf: torch.Tensor | None = None, tp_size: int = 1, tp_rank: int = 0, ep_size: int = 1, ep_rank: int = 0, cluster_size: int = 1, cluster_rank: int = 0, output: torch.Tensor | None = None, enable_alltoall: bool = False, use_deepseek_fp8_block_scale: bool = False, use_w4a8_group_scaling: bool = False, use_mxfp8_act_scaling: bool = False, min_latency_mode: bool = False, tune_max_num_tokens: int = 8192) → torch.Tensor
Compute a Mixture of Experts (MoE) layer using the CUTLASS backend.
This function implements a fused MoE layer that combines expert selection, expert computation, and output combination into a single operation. It uses CUTLASS for efficient matrix multiplication and supports various data types and parallelism strategies.
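Example
A minimal usage sketch with unquantized bfloat16 weights. The weight layouts, index dtypes, and the empty quant_scales list below are illustrative assumptions, not a specification given on this page.

import torch
from flashinfer.fused_moe import cutlass_fused_moe

num_tokens, hidden_size, intermediate_size = 16, 1024, 4096
num_experts, top_k = 8, 2

# Routing produced by an external gate/router (not part of this op).
token_selected_experts = torch.randint(
    0, num_experts, (num_tokens, top_k), dtype=torch.int32, device="cuda")
token_final_scales = torch.rand(
    num_tokens, top_k, dtype=torch.float32, device="cuda")

x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")

# Assumed layouts: GEMM1 packs the per-expert gate/up projections,
# GEMM2 holds the per-expert down projection.
fc1_expert_weights = torch.randn(
    num_experts, 2 * intermediate_size, hidden_size,
    dtype=torch.bfloat16, device="cuda")
fc2_expert_weights = torch.randn(
    num_experts, hidden_size, intermediate_size,
    dtype=torch.bfloat16, device="cuda")

out = cutlass_fused_moe(
    x,
    token_selected_experts,
    token_final_scales,
    fc1_expert_weights,
    fc2_expert_weights,
    output_dtype=torch.bfloat16,
    quant_scales=[],  # assumption: no quantization scales for bf16 weights
)
# out has shape [num_tokens, hidden_size]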
- Parameters:
input (torch.Tensor) – Input tensor of shape [num_tokens, hidden_size]. Supported dtypes are float32, float16, bfloat16, float8_e4m3fn, and nvfp4. For FP8, the input must be quantized. For NVFP4, both quantized and non-quantized inputs are supported.
token_selected_experts (torch.Tensor) – Indices of selected experts for each token.
token_final_scales (torch.Tensor) – Scaling factors for each token’s expert outputs.
fc1_expert_weights (torch.Tensor) – GEMM1 weights for each expert.
fc2_expert_weights (torch.Tensor) – GEMM2 weights for each expert.
output_dtype (torch.dtype) – Desired output data type.
quant_scales (List[torch.Tensor]) – Quantization scales for the operation; the expected entries depend on the quantization mode (an FP8 example is sketched at the end of this page).
- NVFP4:
  - gemm1 activation global scale
  - gemm1 weights block scales
  - gemm1 dequant scale
  - gemm2 activation global scale
  - gemm2 weights block scales
  - gemm2 dequant scale
- FP8:
  - gemm1 dequant scale
  - gemm2 activation quant scale
  - gemm2 dequant scale
  - gemm1 input dequant scale
fc1_expert_biases (Optional[torch.Tensor]) – GEMM1 biases for each expert.
fc2_expert_biases (Optional[torch.Tensor]) – GEMM2 biases for each expert.
input_sf (Optional[torch.Tensor]) – Input scaling factor for quantization.
tp_size (int = 1) – Tensor parallelism size. Defaults to 1.
tp_rank (int = 0) – Tensor parallelism rank. Defaults to 0.
ep_size (int = 1) – Expert parallelism size. Defaults to 1.
ep_rank (int = 0) – Expert parallelism rank. Defaults to 0.
cluster_size (int = 1) – Cluster size. Defaults to 1.
cluster_rank (int = 0) – Cluster rank. Defaults to 0.
output (Optional[torch.Tensor] = None) – The output tensor; if not provided, it will be allocated internally.
enable_alltoall (bool = False) – Whether to enable all-to-all communication for expert outputs. Defaults to False.
use_deepseek_fp8_block_scale (bool = False) – Whether to use FP8 block scaling. Defaults to False.
use_w4a8_group_scaling (bool = False) – Whether to use W4A8 group scaling. Defaults to False.
use_mxfp8_act_scaling (bool = False) – Whether to use MXFP8 activation scaling. Defaults to False.
min_latency_mode (bool = False) – Whether to use minimum latency mode. Defaults to False.
tune_max_num_tokens (int = 8192) – Maximum number of tokens for tuning. Defaults to 8192.
- Returns:
out – Output tensor of shape [num_tokens, hidden_size] (same shape as the input).
- Return type:
torch.Tensor
- Raises:
NotImplementedError – If any of the following features are requested but not implemented:
  - FP8 Block Scaling
  - W4A8 Group Scaling
  - Minimum Latency Mode
Note
- The function supports various data types including FP32, FP16, BF16, FP8, and NVFP4.
- It implements both tensor parallelism and expert parallelism.
- Currently, some advanced features such as FP8 block scaling and minimum latency mode are not implemented for the Blackwell architecture.
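The FP8 ordering listed under the quant_scales parameter can be assembled as in the following sketch. The shapes of the individual scale tensors (per-expert vectors vs. scalars) are assumptions for illustration; only their order follows the FP8 entry above.

import torch

num_experts = 8

gemm1_dequant_scale = torch.ones(num_experts, dtype=torch.float32, device="cuda")
gemm2_act_quant_scale = torch.ones(1, dtype=torch.float32, device="cuda")
gemm2_dequant_scale = torch.ones(num_experts, dtype=torch.float32, device="cuda")
gemm1_input_dequant_scale = torch.ones(1, dtype=torch.float32, device="cuda")

# Order follows the FP8 entry of the quant_scales parameter above.
quant_scales = [
    gemm1_dequant_scale,        # gemm1 dequant scale
    gemm2_act_quant_scale,      # gemm2 activation quant scale
    gemm2_dequant_scale,        # gemm2 dequant scale
    gemm1_input_dequant_scale,  # gemm1 input dequant scale
]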