flashinfer.comm.trtllm_moe_finalize_allreduce_fusion

flashinfer.comm.trtllm_moe_finalize_allreduce_fusion(allreduce_in: Tensor, residual_in: Tensor, norm_weight: Tensor, expanded_idx_to_permuted_idx: Tensor, norm_out: Tensor | None, residual_out: Tensor | None, quant_out: Tensor | None, scale_out: Tensor | None, workspace_ptrs: Tensor, launch_with_pdl: bool, world_rank: int, world_size: int, eps: float, shared_expert_output: Tensor | None, expert_scale_factor: Tensor | None, routed_scaling_factor: float | None, weight_bias: float | None = None) -> None

Parameters:

  • allreduce_in: the permuted/padded MoE expert output tensor. [num_permuted_rows, hidden_dim]

    Rows are referenced by expanded_idx_to_permuted_idx; num_permuted_rows may be larger than token_num * top_k due to expert padding.

  • residual_in: the residual input tensor. [token_num, hidden_dim]

  • norm_weight: the norm weight tensor. [hidden_dim]

  • expanded_idx_to_permuted_idx: maps each (token, top_k slot) pair to its row in allreduce_in. [token_num, top_k]

  • norm_out: the norm output tensor. [token_num, hidden_dim]

  • residual_out: the residual output tensor. [token_num, hidden_dim]

  • quant_out: the quant output tensor. [token_num // 4, hidden_dim], fp16/bf16 -> fp4

  • scale_out: the scale output tensor. [token_num // SF_VEC_SIZE, hidden_dim], fp16/bf16 -> fp4

  • workspace_ptrs: tensor holding the pointers to the communication workspace shared across ranks for the all-reduce.

  • launch_with_pdl: whether to launch the kernel with CUDA Programmatic Dependent Launch (PDL).

  • world_rank: the rank of the current process.

  • world_size: the size of the process group.

  • eps: the epsilon added to the mean square inside the RMSNorm rsqrt for numerical stability.

  • shared_expert_output: the shared expert output tensor. [token_num, hidden_dim]

  • expert_scale_factor: per-(token, top_k slot) scale factors applied to the gathered expert outputs. [token_num, top_k]

  • routed_scaling_factor: optional scalar scaling factor for the routed (non-shared) expert contribution.

  • weight_bias: a constant added to the RMSNorm gamma (norm_weight) before it scales the normalized input.

    None or 0.0 -> standard RMSNorm (out = gamma * x * rsqrt(…)).
    1.0 -> Gemma / Qwen3.5 RMSNorm (out = (1 + gamma) * x * rsqrt(…)).
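The parameters above describe a gather, weight, all-reduce, residual-add and RMSNorm pipeline. The pure-Python sketch below is a hypothetical reference of that data flow, not the FlashInfer kernel or its API. It assumes that every rank uses the same expanded_idx_to_permuted_idx, that routed_scaling_factor multiplies only the routed sum, that shared_expert_output is added before the all-reduce, and that norm_out is the RMSNorm of residual_out. The all-reduce is simulated as an elementwise sum over ranks; quant_out, scale_out, workspace_ptrs and PDL are not modeled. The helper names rms_norm, local_finalize and reference_fusion are hypothetical.

```python
import math


def rms_norm(x, gamma, eps, weight_bias=None):
    # None or 0.0 -> gamma * x * rsqrt(mean(x^2) + eps);
    # 1.0 -> (1 + gamma) * x * rsqrt(mean(x^2) + eps).
    bias = 0.0 if weight_bias is None else weight_bias
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [(bias + g) * v * inv_rms for v, g in zip(x, gamma)]


def local_finalize(permuted, idx_map, scales, shared, routed_scaling_factor):
    # One rank's MoE finalize: gather each token's top_k rows from the
    # permuted/padded buffer and take their weighted sum. Padding rows that
    # no (token, k) pair points at are never read.
    out = []
    for t, rows in enumerate(idx_map):
        acc = [0.0] * len(permuted[0])
        for k, row in enumerate(rows):
            w = 1.0 if scales is None else scales[t][k]
            acc = [a + w * v for a, v in zip(acc, permuted[row])]
        if routed_scaling_factor is not None:
            # Assumption: the scalar applies to the routed sum only.
            acc = [routed_scaling_factor * a for a in acc]
        if shared is not None:
            # Assumption: shared expert output is added before the all-reduce.
            acc = [a + s for a, s in zip(acc, shared[t])]
        out.append(acc)
    return out


def reference_fusion(per_rank_in, residual_in, norm_weight, idx_map, eps,
                     per_rank_shared=None, scales=None,
                     routed_scaling_factor=None, weight_bias=None):
    # Each simulated rank finalizes its own partial MoE output.
    partials = [
        local_finalize(p, idx_map, scales,
                       None if per_rank_shared is None else per_rank_shared[r],
                       routed_scaling_factor)
        for r, p in enumerate(per_rank_in)
    ]
    # Simulated all-reduce: elementwise sum of every rank's [token][hidden] output.
    reduced = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
    # Residual add, then RMSNorm of the updated residual stream.
    residual_out = [[a + b for a, b in zip(r, x)]
                    for r, x in zip(reduced, residual_in)]
    norm_out = [rms_norm(r, norm_weight, eps, weight_bias) for r in residual_out]
    return norm_out, residual_out
```

For example, with two ranks, two tokens, top_k = 2 and five permuted rows (row 3 is padding), token 0 gathers rows 0 and 2 with weights 0.5 each, and token 1 gathers rows 1 and 4 with weights 1.0 and 0.0. Keeping the padding row unreferenced mirrors the note that num_permuted_rows may exceed token_num * top_k.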