flashinfer.fused_moe.trtllm_fp4_block_scale_moe¶
- flashinfer.fused_moe.trtllm_fp4_block_scale_moe(routing_logits: torch.Tensor, routing_bias: torch.Tensor | None, hidden_states: torch.Tensor, hidden_states_scale: torch.Tensor | None, gemm1_weights: torch.Tensor, gemm1_weights_scale: torch.Tensor, gemm1_bias: torch.Tensor | None, gemm1_alpha: torch.Tensor | None, gemm1_beta: torch.Tensor | None, gemm1_clamp_limit: torch.Tensor | None, gemm2_weights: torch.Tensor, gemm2_weights_scale: torch.Tensor, gemm2_bias: torch.Tensor | None, output1_scale_scalar: torch.Tensor | None, output1_scale_gate_scalar: torch.Tensor | None, output2_scale_scalar: torch.Tensor | None, num_experts: int, top_k: int, n_group: int | None, topk_group: int | None, intermediate_size: int, local_expert_offset: int, local_num_experts: int, routed_scaling_factor: float | None, tile_tokens_dim: int = 8, routing_method_type: int = 0, do_finalize: bool = True, output: torch.Tensor | None = None) → List[torch.Tensor]¶
FP4 block-scale Mixture-of-Experts (MoE) operation.
- Parameters:
routing_logits (torch.Tensor) – shape [seq_len, num_experts] Input tensor of routing logits. Supports float32 and bfloat16.
routing_bias (Optional[torch.Tensor]) – shape [num_experts] Tensor of routing bias; can be None for some routing methods. Must have the same dtype as routing_logits.
hidden_states (torch.Tensor) – shape [seq_len, hidden_size // 2 if nvfp4 else hidden_size] Tensor of input hidden states. Supports bfloat16, mxfp8, and nvfp4 (packed two values per uint8).
hidden_states_scale (Optional[torch.Tensor]) – shape [seq_len, hidden_size // (32 if mxfp8 else 16)] Scale tensor for mxfp8 / nvfp4 hidden states (block size 32 for mxfp8, 16 for nvfp4). Dtype must be float8.
gemm1_weights (torch.Tensor) – shape [num_experts, 2 * intermediate_size, hidden_size // 2] Tensor of FC1 weights. Dtype must be uint8 (packed fp4).
gemm1_weights_scale (torch.Tensor) – shape [num_experts, 2 * intermediate_size, hidden_size // (32 if mxfp4 else 16)] Scale tensor of FC1 weights (block size 32 for mxfp4, 16 for nvfp4). Dtype must be float8.
gemm2_weights (torch.Tensor) – shape [num_experts, hidden_size, intermediate_size // 2] Tensor of FC2 weights. Dtype must be uint8 (packed fp4).
gemm2_weights_scale (torch.Tensor) – shape [num_experts, hidden_size, intermediate_size // (32 if mxfp4 else 16)] Scale tensor of FC2 weights (block size 32 for mxfp4, 16 for nvfp4). Dtype must be float8.
output1_scale_scalar (Optional[torch.Tensor]) – shape [local_num_experts] Tensor of scaling factors for the first layer's activation output.
output1_scale_gate_scalar (Optional[torch.Tensor]) – shape [local_num_experts] Tensor of scaling factors for the first layer's gate output.
output2_scale_scalar (Optional[torch.Tensor]) – shape [local_num_experts] Tensor of scaling factors for the second layer's output.
num_experts (int) – Total number of experts
top_k (int) – Number of experts to route to per token
n_group (Optional[int]) – Number of expert groups (can be None for some routing methods)
topk_group (Optional[int]) – Number of groups to consider for top-k routing (can be None for some routing methods)
intermediate_size (int) – Size of intermediate layer
local_expert_offset (int) – Offset of local experts in global expert space
local_num_experts (int) – Number of experts handled by this device
routed_scaling_factor (Optional[float]) – Scaling factor for routing (can be None for some routing methods)
tile_tokens_dim (int) – Tile dimension for tokens (default: 8)
routing_method_type (int) – Type of routing method to use (default: 0); modes 0 and 1 are sketched in plain PyTorch after this parameter list.
  - 0: Default (Softmax -> TopK)
  - 1: Renormalize (TopK -> Softmax)
  - 2: DeepSeekV3 (Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts)
  - 3: Llama4 (Top1 -> Sigmoid)
  - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
do_finalize (bool) – Whether to finalize the output (default: True)
output (Optional[torch.Tensor]) – shape [seq_len, hidden_size] Optional preallocated output tensor; when provided, the result is written into it in place.
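For intuition, routing modes 0 and 1 differ only in the order of softmax and top-k. The following is a minimal plain-PyTorch sketch of their reference semantics; it is an illustrative approximation, not the fused kernel's exact numerics or tie-breaking.

```python
import torch

def route_softmax_topk(logits: torch.Tensor, top_k: int):
    # Mode 0 (Default): softmax over all experts, then keep the top-k probabilities.
    probs = torch.softmax(logits, dim=-1)
    weights, experts = torch.topk(probs, top_k, dim=-1)
    return weights, experts

def route_topk_softmax(logits: torch.Tensor, top_k: int):
    # Mode 1 (Renormalize): keep the top-k logits first, then softmax over just
    # those, so each token's selected expert weights sum to 1.
    top_logits, experts = torch.topk(logits, top_k, dim=-1)
    weights = torch.softmax(top_logits, dim=-1)
    return weights, experts

logits = torch.randn(4, 8)  # [seq_len, num_experts]
w0, e0 = route_softmax_topk(logits, top_k=2)
w1, e1 = route_topk_softmax(logits, top_k=2)
```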
- Returns:
- List of output tensors. If do_finalize=True, the list contains the final MoE output of shape [seq_len, hidden_size].
Otherwise, it contains intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that require further processing, e.g. a caller-side finalize/combine step.
- Return type:
List[torch.Tensor]
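A shape-oriented usage sketch for the nvfp4 path follows. This is an assumption-laden illustration, not an official example: it fills weights and scales with random bits purely to satisfy the documented shapes and dtypes, assumes the float8 scale dtype is torch.float8_e4m3fn, and requires a CUDA device supported by the kernel. Real use requires properly quantized fp4 weights and scales.

```python
import torch
from flashinfer.fused_moe import trtllm_fp4_block_scale_moe

seq_len, hidden_size, intermediate_size = 4, 1024, 2048
num_experts, top_k = 8, 2
dev = "cuda"

# Routing logits (float32) and nvfp4 hidden states (two fp4 values per uint8).
routing_logits = torch.randn(seq_len, num_experts, dtype=torch.float32, device=dev)
hidden_states = torch.randint(
    0, 256, (seq_len, hidden_size // 2), dtype=torch.uint8, device=dev)
# Per-16-element block scales; float8_e4m3fn is an assumed scale dtype.
hidden_states_scale = torch.randint(
    0, 256, (seq_len, hidden_size // 16), dtype=torch.uint8, device=dev
).view(torch.float8_e4m3fn)

# Packed fp4 expert weights and their block scales (random placeholders).
gemm1_weights = torch.randint(
    0, 256, (num_experts, 2 * intermediate_size, hidden_size // 2),
    dtype=torch.uint8, device=dev)
gemm1_weights_scale = torch.randint(
    0, 256, (num_experts, 2 * intermediate_size, hidden_size // 16),
    dtype=torch.uint8, device=dev).view(torch.float8_e4m3fn)
gemm2_weights = torch.randint(
    0, 256, (num_experts, hidden_size, intermediate_size // 2),
    dtype=torch.uint8, device=dev)
gemm2_weights_scale = torch.randint(
    0, 256, (num_experts, hidden_size, intermediate_size // 16),
    dtype=torch.uint8, device=dev).view(torch.float8_e4m3fn)

out = trtllm_fp4_block_scale_moe(
    routing_logits, None,                    # routing_bias
    hidden_states, hidden_states_scale,
    gemm1_weights, gemm1_weights_scale,
    None, None, None, None,                  # gemm1_bias / alpha / beta / clamp_limit
    gemm2_weights, gemm2_weights_scale,
    None,                                    # gemm2_bias
    None, None, None,                        # output1/output1_gate/output2 scale scalars
    num_experts=num_experts, top_k=top_k, n_group=None, topk_group=None,
    intermediate_size=intermediate_size, local_expert_offset=0,
    local_num_experts=num_experts, routed_scaling_factor=None,
    routing_method_type=0, do_finalize=True,
)[0]  # final MoE output, shape [seq_len, hidden_size]
```

With do_finalize=False, the same call returns the intermediate (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) tensors instead, for callers that implement their own finalize/combine step.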