flashinfer.comm.trtllm_custom_all_reduce
- flashinfer.comm.trtllm_custom_all_reduce(inp: Tensor, out: Tensor, tp_size: int, tp_rank: int, token_num: int, fusion_op_code: AllReduceFusionOp, strategy_code: AllReduceStrategyType, config_code: AllReduceStrategyConfig, launch_with_pdl: bool, flag_value: int, peer_comm_buffer_ptrs: Tensor, peer_barrier_ptrs_in: Tensor, peer_barrier_ptrs_out: Tensor, bias: Tensor | None, residual: Tensor | None, weight: Tensor | None, weight_pre_residual_norm: Tensor | None, eps: float | None, intermediate_buffer: Tensor | None, lamport_peer_comm_buffer_ptrs_0: Tensor | None, lamport_peer_comm_buffer_ptrs_1: Tensor | None, lamport_peer_comm_buffer_ptrs_2: Tensor | None) → None
Parameters:
- inp: the input tensor, shape [token_num, hidden_dim].
- out: the output tensor, shape [token_num, hidden_dim].
- tp_size: the size of the process group.
- tp_rank: the rank of the current process.
- token_num: the number of tokens in the sequence.
- fusion_op_code: the fusion operation code (an AllReduceFusionOp value).
- strategy_code: the strategy code (an AllReduceStrategyType value).
- config_code: the config code (an AllReduceStrategyConfig value).
- launch_with_pdl: whether to launch with PDL (programmatic dependent launch).
- flag_value: the flag value.
- peer_comm_buffer_ptrs: the peer communication buffer pointers.
- peer_barrier_ptrs_in: the input peer barrier pointers.
- peer_barrier_ptrs_out: the output peer barrier pointers.
- bias: the bias tensor, shape [hidden_dim]. May be None.
- residual: the residual tensor, shape [token_num, hidden_dim]. May be None.
- weight: the weight tensor, shape [hidden_dim]. May be None.
- weight_pre_residual_norm: the pre-residual-norm weight tensor, shape [hidden_dim]. May be None.
- eps: the epsilon value. May be None.
- intermediate_buffer: the intermediate buffer tensor. May be None.
- lamport_peer_comm_buffer_ptrs_0: the Lamport peer communication buffer pointers 0. May be None.
- lamport_peer_comm_buffer_ptrs_1: the Lamport peer communication buffer pointers 1. May be None.
- lamport_peer_comm_buffer_ptrs_2: the Lamport peer communication buffer pointers 2. May be None.
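The shape constraints in the parameter list above can be checked before the call is made. The following is a minimal sketch in plain Python: `expected_shapes` and `check_shapes` are hypothetical helpers written for illustration, not part of FlashInfer, and the sketch assumes that only `inp` and `out` are mandatory among the shaped tensors, since the signature marks the others as `Tensor | None`. It does not call `trtllm_custom_all_reduce` itself, which needs a multi-GPU setup with the peer buffers already allocated.

```python
from typing import Dict, List, Optional, Tuple

Shape = Tuple[int, ...]


def expected_shapes(token_num: int, hidden_dim: int) -> Dict[str, Shape]:
    """Shapes taken from the documented parameter list (hypothetical helper)."""
    per_token = (token_num, hidden_dim)
    per_channel = (hidden_dim,)
    return {
        "inp": per_token,
        "out": per_token,
        "residual": per_token,
        "bias": per_channel,
        "weight": per_channel,
        "weight_pre_residual_norm": per_channel,
    }


def check_shapes(
    token_num: int,
    hidden_dim: int,
    actual: Dict[str, Optional[Shape]],
) -> List[str]:
    """Return readable mismatch messages; an empty list means consistent."""
    # Assumption: inp and out are the only shaped tensors that must be present.
    required = {"inp", "out"}
    problems = []
    for name, want in expected_shapes(token_num, hidden_dim).items():
        got = actual.get(name)
        if got is None:
            if name in required:
                problems.append(f"{name}: required but missing")
            continue  # optional tensors may be omitted or None
        if tuple(got) != want:
            problems.append(f"{name}: expected {want}, got {tuple(got)}")
    return problems


if __name__ == "__main__":
    problems = check_shapes(
        token_num=4,
        hidden_dim=8,
        actual={"inp": (4, 8), "out": (4, 8), "bias": (8,), "residual": None},
    )
    print(problems)  # prints []
```

With PyTorch tensors, the shapes would be passed as `tuple(t.shape)`. Returning a list of messages rather than raising on the first mismatch lets a caller report every inconsistent argument at once, which is convenient for a function with this many parameters.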