flashinfer.fused_moe.trtllm_fp4_block_scale_moe

flashinfer.fused_moe.trtllm_fp4_block_scale_moe(routing_logits: torch.Tensor, routing_bias: torch.Tensor | None, hidden_states: torch.Tensor, hidden_states_scale: torch.Tensor | None, gemm1_weights: torch.Tensor, gemm1_weights_scale: torch.Tensor, gemm1_bias: torch.Tensor | None, gemm1_alpha: torch.Tensor | None, gemm1_beta: torch.Tensor | None, gemm1_clamp_limit: torch.Tensor | None, gemm2_weights: torch.Tensor, gemm2_weights_scale: torch.Tensor, gemm2_bias: torch.Tensor | None, output1_scale_scalar: torch.Tensor | None, output1_scale_gate_scalar: torch.Tensor | None, output2_scale_scalar: torch.Tensor | None, num_experts: int, top_k: int, n_group: int | None, topk_group: int | None, intermediate_size: int, local_expert_offset: int, local_num_experts: int, routed_scaling_factor: float | None, tile_tokens_dim: int = 8, routing_method_type: int = 0, do_finalize: bool = True, output: torch.Tensor | None = None) → List[torch.Tensor]

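FP4 block-scale MoE operation: a fused Mixture-of-Experts layer whose expert weights are stored as block-scaled FP4 (nvfp4 or mxfp4).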

Parameters:
  • routing_logits (torch.Tensor) – shape [seq_len, num_experts] Input tensor of routing logits. Supports float32, bfloat16.

  • routing_bias (Optional[torch.Tensor]) – shape [num_experts] Tensor of routing bias. Can be None for some routing methods. Must have the same dtype as routing_logits.

  • hidden_states (torch.Tensor) – shape [seq_len, hidden_size // 2 if nvfp4 else hidden_size] Tensor of input hidden states. Supports bfloat16, mxfp8, and nvfp4 (two values packed per uint8 byte).

  • hidden_states_scale (Optional[torch.Tensor]) – shape [seq_len, hidden_size // (32 if mxfp8, 16 if nvfp4)] Scale tensor of mxfp8 / nvfp4 hidden states. Dtype must be float8.

  • gemm1_weights (torch.Tensor) – shape [num_experts, 2 * intermediate_size, hidden_size // 2] Tensor of FC1 weights. Dtype must be uint8 (packed fp4, two values per byte).

  • gemm1_weights_scale (torch.Tensor) – shape [num_experts, 2 * intermediate_size, hidden_size // (32 if mxfp4 else 16)] Scale tensor of FC1 weights. Dtype must be float8.

  • gemm2_weights (torch.Tensor) – shape [num_experts, hidden_size, intermediate_size // 2] Tensor of FC2 weights. Dtype must be uint8 (packed fp4, two values per byte).

  • gemm2_weights_scale (torch.Tensor) – shape [num_experts, hidden_size, intermediate_size // (32 if mxfp4 else 16)] Scale tensor of FC2 weights. Dtype must be float8.

  • output1_scale_scalar (Optional[torch.Tensor]) – shape [local_num_experts] Per-expert scaling factors for the first-layer (FC1) activation output.

  • output1_scale_gate_scalar (Optional[torch.Tensor]) – shape [local_num_experts] Per-expert scaling factors for the first-layer (FC1) gate output.

  • output2_scale_scalar (Optional[torch.Tensor]) – shape [local_num_experts] Per-expert scaling factors for the second-layer (FC2) output.

  • num_experts (int) – Total number of experts

  • top_k (int) – Number of experts to route to per token

  • n_group (Optional[int]) – Number of expert groups (can be None for some routing methods)

  • topk_group (Optional[int]) – Number of groups to consider for top-k routing (can be None for some routing methods)

  • intermediate_size (int) – Size of intermediate layer

  • local_expert_offset (int) – Offset of local experts in global expert space

  • local_num_experts (int) – Number of experts handled by this device

  • routed_scaling_factor (Optional[float]) – Scaling factor for routing (can be None for some routing methods)

  • tile_tokens_dim (int) – Tile dimension for tokens (default: 8)

  • routing_method_type (int) – Type of routing method to use (default: 0):
      - 0: Default (Softmax -> TopK)
      - 1: Renormalize (TopK -> Softmax)
      - 2: DeepSeekV3 (Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts)
      - 3: Llama4 (Top1 -> Sigmoid)
      - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)

  • do_finalize (bool) – Whether to finalize the output (default: True). If False, the intermediate results described under Returns are returned instead.

  • output (Optional[torch.Tensor]) – shape [seq_len, hidden_size] Optional preallocated output tensor; if provided, the result is written into it in place.
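The packed and scale shapes above follow one rule: FP4 formats store two 4-bit values per uint8 byte, and each scale factor covers a block of 16 (nvfp4) or 32 (mxfp4, mxfp8) elements along the reduction dimension. The helper below is a hypothetical sketch of that arithmetic, not a FlashInfer function; it computes logical shapes only and ignores any padded or swizzled scale layout the kernels may expect.

```python
from typing import Tuple

# Elements covered by one scale factor (block size) per format, as documented above.
BLOCK_SIZE = {"nvfp4": 16, "mxfp4": 32, "mxfp8": 32}
FP4_FORMATS = ("nvfp4", "mxfp4")


def packed_shapes(rows: int, k: int, fmt: str) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Hypothetical helper: (data_shape, scale_shape) for a [rows, k] matrix
    quantized block-wise along its last (reduction) dimension k."""
    block = BLOCK_SIZE[fmt]
    if k % block:
        raise ValueError(f"k={k} is not a multiple of the {fmt} block size {block}")
    # FP4 stores two 4-bit values per uint8 byte; mxfp8 uses one byte per value.
    data_k = k // 2 if fmt in FP4_FORMATS else k
    return (rows, data_k), (rows, k // block)


# hidden_states in nvfp4 with seq_len=4, hidden_size=256
print(packed_shapes(4, 256, "nvfp4"))  # ((4, 128), (4, 16))
# One expert's gemm1_weights in mxfp4: 2 * intermediate_size = 1024 rows, hidden_size = 256
print(packed_shapes(1024, 256, "mxfp4"))  # ((1024, 128), (1024, 8))
```

The first call matches the documented hidden_states / hidden_states_scale shapes, and the second matches one expert's slice of gemm1_weights / gemm1_weights_scale.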
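To make "block scale" concrete, here is a simplified nvfp4-style quantizer in pure Python. Each block of 16 values shares one scale (amax / 6, since 6 is the largest FP4 E2M1 magnitude), and each value becomes a 4-bit sign/magnitude code. This is an illustrative sketch under stated assumptions, not the kernels' quantizer: the scale stays a Python float rather than float8, ties round toward the smaller magnitude instead of to even, any per-tensor global scale is omitted, and the low-nibble-first packing order is assumed.

```python
from typing import List, Tuple

# FP4 E2M1 magnitudes; the list index is the 3-bit magnitude code, bit 3 is the sign.
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def quantize_block(values: List[float]) -> Tuple[float, List[int]]:
    """Quantize one block to (scale, 4-bit codes) using scale = amax / 6."""
    amax = max(abs(v) for v in values)
    scale = amax / 6.0 if amax > 0 else 1.0
    codes = []
    for v in values:
        mag = abs(v) / scale
        # Nearest grid point; min() keeps the first (smaller) index on ties.
        idx = min(range(len(E2M1_GRID)), key=lambda i: abs(E2M1_GRID[i] - mag))
        codes.append(idx | (0x8 if v < 0 else 0))
    return scale, codes


def pack_nibbles(codes: List[int]) -> bytes:
    """Pack pairs of 4-bit codes into bytes (assumed order: first code in the low nibble)."""
    if len(codes) % 2:
        raise ValueError("need an even number of codes")
    return bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))


def quantize_row(row: List[float], block: int = 16) -> Tuple[bytes, List[float]]:
    """Quantize a row whose length is a multiple of `block` (16 for nvfp4)."""
    if len(row) % block:
        raise ValueError(f"row length {len(row)} is not a multiple of {block}")
    packed, scales = bytearray(), []
    for start in range(0, len(row), block):
        scale, codes = quantize_block(row[start:start + block])
        scales.append(scale)
        packed += pack_nibbles(codes)
    return bytes(packed), scales


def dequantize_row(packed: bytes, scales: List[float], block: int = 16) -> List[float]:
    """Invert quantize_row up to FP4 rounding error."""
    out: List[float] = []
    for byte in packed:
        for code in (byte & 0xF, byte >> 4):
            scale = scales[len(out) // block]
            mag = E2M1_GRID[code & 0x7] * scale
            out.append(-mag if code & 0x8 else mag)
    return out


row = [0.1 * i for i in range(32)]
packed, scales = quantize_row(row)
print(len(packed), len(scales))  # 16 2
```

A 32-element row yields 16 bytes (hidden_size // 2) and 2 scales (hidden_size // 16), which is the same ratio the shapes above encode.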
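Routing methods 0, 1, and 4 differ only in where the softmax and normalization happen. The pure-Python sketch below (hypothetical helper names, float64 math, ties broken by lower expert index) shows that Renormalize (TopK -> Softmax) and RenormalizeNaive (Softmax -> TopK -> Renormalize) select the same experts with the same weights, while Default leaves the selected probabilities unnormalized. It does not cover DeepSeekV3 or Llama4 routing and is not the kernel implementation.

```python
import math
from typing import List, Tuple


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def topk_indices(xs: List[float], k: int) -> List[int]:
    # sorted() is stable even with reverse=True, so ties keep the lower expert index first.
    return sorted(range(len(xs)), key=lambda i: xs[i], reverse=True)[:k]


def route(logits: List[float], top_k: int, method: int) -> Tuple[List[int], List[float]]:
    """Hypothetical reference for routing_method_type 0, 1 and 4 on one token."""
    if method == 0:  # Default: Softmax -> TopK
        probs = softmax(logits)
        idx = topk_indices(probs, top_k)
        return idx, [probs[i] for i in idx]
    if method == 1:  # Renormalize: TopK -> Softmax
        idx = topk_indices(logits, top_k)
        return idx, softmax([logits[i] for i in idx])
    if method == 4:  # RenormalizeNaive: Softmax -> TopK -> Renormalize
        probs = softmax(logits)
        idx = topk_indices(probs, top_k)
        total = sum(probs[i] for i in idx)
        return idx, [probs[i] / total for i in idx]
    raise NotImplementedError(f"routing_method_type={method} is not sketched here")


logits = [2.0, 1.0, 0.5, 3.0]
for method in (0, 1, 4):
    print(method, route(logits, top_k=2, method=method))
```

Methods 1 and 4 agree because softmax is monotonic (same top-k set) and renormalizing a subset of softmax probabilities cancels the shared denominator.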

Returns:

List of output tensors. If do_finalize=True, returns the final MoE output.

Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing.
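With do_finalize=False, the caller combines the intermediate results itself. The sketch below shows one plausible form of that reduction under assumptions not confirmed by this page: expert_weights has shape [seq_len, top_k]; expanded_idx_to_permuted_idx is indexed by token * top_k + k and gives the gemm2_output row holding that token's output from its k-th expert; and a negative index marks a slot whose expert is not handled on this device (outside local_expert_offset .. local_expert_offset + local_num_experts). The function name is hypothetical.

```python
from typing import List, Sequence


def finalize(gemm2_output: Sequence[Sequence[float]],
             expert_weights: Sequence[Sequence[float]],
             expanded_idx_to_permuted_idx: Sequence[int],
             top_k: int) -> List[List[float]]:
    """Hypothetical finalize: out[t] = sum_k weight[t][k] * gemm2_output[row(t, k)]."""
    seq_len = len(expert_weights)
    hidden = len(gemm2_output[0])
    out = [[0.0] * hidden for _ in range(seq_len)]
    for t in range(seq_len):
        for k in range(top_k):
            row = expanded_idx_to_permuted_idx[t * top_k + k]
            if row < 0:  # assumption: negative index = expert not on this device
                continue
            w = expert_weights[t][k]
            for h in range(hidden):
                out[t][h] += w * gemm2_output[row][h]
    return out


gemm2_output = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, 3.0]]  # rows in permuted order
expert_weights = [[0.5, 0.5], [0.25, 0.75]]                         # [seq_len=2, top_k=2]
print(finalize(gemm2_output, expert_weights, [2, 0, 1, 3], top_k=2))  # [[1.5, 1.0], [2.25, 2.5]]
```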

Return type:

List[torch.Tensor]