flashinfer.comm.trtllm_moe_finalize_allreduce_fusion
- flashinfer.comm.trtllm_moe_finalize_allreduce_fusion(allreduce_in: Tensor, residual_in: Tensor, norm_weight: Tensor, expanded_idx_to_permuted_idx: Tensor, norm_out: Tensor, residual_out: Tensor, workspace_ptrs: Tensor, launch_with_pdl: bool, world_rank: int, world_size: int, eps: float, shared_expert_output: Tensor | None, expert_scale_factor: Tensor | None) -> None
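Parameters:
- allreduce_in: the per-rank input tensor to be reduced across ranks, shape [token_num, top_k, hidden_dim].
- residual_in: the residual input tensor, shape [token_num, hidden_dim].
- norm_weight: the normalization weight tensor, shape [hidden_dim].
- expanded_idx_to_permuted_idx: the mapping from each expanded (token, top-k slot) index to its permuted index, shape [token_num, top_k].
- norm_out: the output tensor that receives the normalized result, shape [token_num, hidden_dim].
- residual_out: the output tensor that receives the updated residual, shape [token_num, hidden_dim].
- workspace_ptrs: the tensor holding the communication workspace pointers.
- launch_with_pdl: whether to launch the kernel with PDL (Programmatic Dependent Launch).
- world_rank: the rank of the current process in the process group.
- world_size: the number of processes in the process group.
- eps: the epsilon value used for numerical stability in the normalization.
- shared_expert_output: the optional shared-expert output tensor, shape [token_num, hidden_dim], or None.
- expert_scale_factor: the optional expert scale factor tensor, shape [token_num, top_k], or None.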
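To make the parameter shapes concrete, here is a minimal single-process reference sketch of what the fused operation might compute, simulating all ranks in plain Python. It is not FlashInfer's implementation, and several semantics are assumptions not stated above: that `allreduce_in` is viewed as `[token_num * top_k, hidden_dim]` rows indexed by `expanded_idx_to_permuted_idx`, that each rank's top-k rows are weighted by `expert_scale_factor` (1.0 when None) and summed together with that rank's `shared_expert_output`, that the per-rank results are summed across ranks, that `residual_out = reduced + residual_in`, and that the normalization is RMSNorm. The function name `moe_finalize_allreduce_reference` is hypothetical.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]  # rows x hidden_dim


def rms_norm(row: List[float], weight: List[float], eps: float) -> List[float]:
    # Assumed RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight
    mean_sq = sum(x * x for x in row) / len(row)
    inv = 1.0 / math.sqrt(mean_sq + eps)
    return [x * inv * w for x, w in zip(row, weight)]


def moe_finalize_allreduce_reference(
    allreduce_in_per_rank: List[List[Matrix]],       # per rank: [token_num][top_k][hidden_dim]
    residual_in: Matrix,                              # [token_num][hidden_dim]
    norm_weight: List[float],                         # [hidden_dim]
    expanded_idx_to_permuted_idx: List[List[int]],    # [token_num][top_k]
    eps: float,
    shared_expert_output_per_rank: Optional[List[Matrix]] = None,  # per rank: [token_num][hidden_dim]
    expert_scale_factor: Optional[List[List[float]]] = None,        # [token_num][top_k]
):
    """Hypothetical reference; returns (norm_out, residual_out)."""
    token_num = len(residual_in)
    hidden_dim = len(norm_weight)
    top_k = len(expanded_idx_to_permuted_idx[0])
    reduced = [[0.0] * hidden_dim for _ in range(token_num)]

    for rank, allreduce_in in enumerate(allreduce_in_per_rank):
        # Assumption: flatten [token_num][top_k] into token_num * top_k rows
        flat = [row for token in allreduce_in for row in token]
        for t in range(token_num):
            for k in range(top_k):
                scale = 1.0 if expert_scale_factor is None else expert_scale_factor[t][k]
                src = flat[expanded_idx_to_permuted_idx[t][k]]
                for h in range(hidden_dim):
                    reduced[t][h] += scale * src[h]  # finalize + allreduce (sum over ranks)
            if shared_expert_output_per_rank is not None:
                for h in range(hidden_dim):
                    reduced[t][h] += shared_expert_output_per_rank[rank][t][h]

    # Residual add, then normalization of the updated residual
    residual_out = [[r + x for r, x in zip(res, red)] for res, red in zip(residual_in, reduced)]
    norm_out = [rms_norm(row, norm_weight, eps) for row in residual_out]
    return norm_out, residual_out
```

For example, with two simulated ranks, one token, `top_k = 2`, and `hidden_dim = 2`, the sketch sums both ranks' top-k rows before the residual add and normalization, which is the work the fused kernel is meant to do in a single launch.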