flashinfer.comm.trtllm_moe_allreduce_fusion
- flashinfer.comm.trtllm_moe_allreduce_fusion(world_size: int, world_rank: int, token_num: int, hidden_dim: int, workspace_ptrs: Tensor, launch_with_pdl: bool, residual_in: Tensor, rms_gamma: Tensor, rms_eps: float, scale_factor: float, moe_reduction_device_num_experts: int, moe_reduction_scale_input: Tensor, moe_reduction_active_experts_token_input: Tensor, moe_reduction_token_input: Tensor, layout_code: QuantizationSFLayout | None, moe_allreduce_out: Tensor | None, residual_out: Tensor | None, norm_out: Tensor | None, quant_out: Tensor | None, scale_out: Tensor | None) → None
Parameters:
- world_size: the size of the process group.
- world_rank: the rank of the current process.
- token_num: the number of tokens in the sequence.
- hidden_dim: the dimension of the hidden states.
- workspace_ptrs: the workspace pointers.
- launch_with_pdl: whether to launch with PDL (programmatic dependent launch).
- residual_in: the residual input tensor. [token_num, hidden_dim]
- rms_gamma: the RMS norm gamma (weight) tensor. [hidden_dim]
- rms_eps: the RMS norm epsilon value.
- scale_factor: the scale factor.
- moe_reduction_device_num_experts: the number of experts.
- moe_reduction_scale_input: the scale input tensor. [token_num, hidden_dim]
- moe_reduction_active_experts_token_input: the active experts token input tensor. [token_num, hidden_dim]
- moe_reduction_token_input: the token input tensor. [token_num, hidden_dim]
- layout_code: the quantization scale-factor layout code.
- moe_allreduce_out: the MoE allreduce output tensor. [token_num, hidden_dim]
- residual_out: the residual output tensor. [token_num, hidden_dim]
- norm_out: the norm output tensor. [token_num, hidden_dim]
- quant_out: the quantized output tensor. [token_num // 4, hidden_dim], fp16/bf16 -> fp4
- scale_out: the scale output tensor.

Initialization reference: tests/comm/test_trtllm_moe_allreduce_fusion.py
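The shape relationships among the tensor arguments can be summarized in a small helper. This is an illustrative sketch only: the helper name is hypothetical and not part of the flashinfer API, and it simply restates the shapes documented above so callers can size their buffers consistently.

```python
def moe_allreduce_fusion_shapes(token_num: int, hidden_dim: int) -> dict:
    """Expected shapes for the tensor arguments of
    trtllm_moe_allreduce_fusion, as documented above.

    Hypothetical helper for illustration -- not part of flashinfer.
    """
    per_token = (token_num, hidden_dim)
    return {
        # inputs
        "residual_in": per_token,
        "rms_gamma": (hidden_dim,),
        "moe_reduction_scale_input": per_token,
        "moe_reduction_active_experts_token_input": per_token,
        "moe_reduction_token_input": per_token,
        # outputs
        "moe_allreduce_out": per_token,
        "residual_out": per_token,
        "norm_out": per_token,
        # fp16/bf16 -> fp4 quantized output, packed along the token dimension
        "quant_out": (token_num // 4, hidden_dim),
    }


shapes = moe_allreduce_fusion_shapes(token_num=128, hidden_dim=4096)
print(shapes["residual_in"])  # (128, 4096)
print(shapes["quant_out"])    # (32, 4096)
```

A typical setup (see the test file referenced above) allocates all of these buffers on the local GPU before invoking the fused kernel once per rank.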