flashinfer.comm.trtllm_allreduce_fusion

flashinfer.comm.trtllm_allreduce_fusion(allreduce_in: Tensor, world_size: int, world_rank: int, token_num: int, hidden_dim: int, workspace_ptrs: Tensor, launch_with_pdl: bool, trigger_completion_at_end: bool, fp32_acc: bool, pattern_code: AllReduceFusionPattern, use_oneshot: bool | None, allreduce_out: Tensor | None, residual_in: Tensor | None, residual_out: Tensor | None, norm_out: Tensor | None, quant_out: Tensor | None, scale_out: Tensor | None, rms_gamma: Tensor | None, rms_eps: float | None, scale_factor: Tensor | float | None, layout_code: QuantizationSFLayout | None) None

Parameters:

- allreduce_in: the input tensor to be all-reduced, of shape [token_num, hidden_dim].
- world_size: the size of the process group.
- world_rank: the rank of the current process.
- token_num: the number of tokens in the sequence.
- hidden_dim: the dimension of the hidden states.
- workspace_ptrs: the workspace pointers.
- launch_with_pdl: whether to launch with PDL (Programmatic Dependent Launch).
- use_oneshot: whether to force the one-shot strategy; see the note below.
- trigger_completion_at_end: whether to trigger completion at the end.
- fp32_acc: whether to use FP32 accumulation.
- pattern_code: the fusion pattern code (an AllReduceFusionPattern value).
- allreduce_out: the all-reduce output tensor, of shape [token_num, hidden_dim].
- residual_in: the residual input tensor, of shape [token_num, hidden_dim].
- residual_out: the residual output tensor, of shape [token_num, hidden_dim].
- norm_out: the norm output tensor, of shape [token_num, hidden_dim].
- quant_out: the quantized output tensor, of shape [token_num, hidden_dim].
- scale_out: the scale output tensor. For an initialization reference, see tests/comm/test_trtllm_allreduce_fusion.py.
- rms_gamma: the RMSNorm gamma tensor, of shape [hidden_dim].
- rms_eps: the RMSNorm epsilon value.
- scale_factor: the scale factor. For CUDA graph safety, it should be passed as a tensor rather than a Python float.
- layout_code: the quantization scale-factor layout code (a QuantizationSFLayout value).
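The shape requirements documented above can be captured in a small validation helper. This is an illustrative sketch, not part of the flashinfer API: the helper name and structure are hypothetical, and it only encodes the shapes listed in the parameter descriptions.

```python
def check_fusion_shapes(token_num: int, hidden_dim: int, tensors: dict) -> list:
    """Return a list of mismatch messages for the documented tensor shapes.

    `tensors` maps parameter names to their shape tuples; optional tensors
    that are not passed are simply absent from the dict.
    """
    # Per the docs: most tensors are [token_num, hidden_dim],
    # while rms_gamma is [hidden_dim].
    expected = {
        "allreduce_in": (token_num, hidden_dim),
        "allreduce_out": (token_num, hidden_dim),
        "residual_in": (token_num, hidden_dim),
        "residual_out": (token_num, hidden_dim),
        "norm_out": (token_num, hidden_dim),
        "quant_out": (token_num, hidden_dim),
        "rms_gamma": (hidden_dim,),
    }
    return [
        f"{name}: expected {expected[name]}, got {tuple(shape)}"
        for name, shape in tensors.items()
        if name in expected and tuple(shape) != expected[name]
    ]
```

Running such a check before launching the kernel can surface shape mistakes early, which is easier to debug than a failure inside a fused multi-GPU kernel.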

Note: Via the use_oneshot parameter, you can explicitly force the one-shot strategy for your use case. If it is left as None, one-shot is enabled automatically when token_num is less than the one-shot max token number (currently 128) for min-latency mode.