flashinfer.comm.trtllm_create_ipc_workspace_for_all_reduce_fusion

flashinfer.comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(tp_rank: int, tp_size: int, max_token_num: int, hidden_dim: int, use_fp32_lamport: bool = False, group: ProcessGroup | None = None) -> Tuple[List[List[int]], Tensor]

Parameters:
- tp_rank: the rank of the current process.
- tp_size: the size of the tensor-parallel process group.
- max_token_num: the maximum number of tokens in a sequence.
- hidden_dim: the dimension of the hidden states.
- use_fp32_lamport: if True, fp32 is used as the datatype in the all-reduce fusion.
- group: the process group to use.
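
Example: A minimal sketch of creating the workspace on every rank of a tensor-parallel group. It assumes the script is launched with a standard launcher (e.g. torchrun) that sets the usual environment variables; the max_token_num and hidden_dim values are illustrative placeholders, not recommended settings.

    import torch
    import torch.distributed as dist
    import flashinfer.comm as comm

    # One process per GPU; the launcher provides RANK / WORLD_SIZE etc.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    # All ranks in the group are expected to call this collectively. The first
    # return value holds the exchanged IPC handles, the second is this rank's
    # workspace tensor.
    ipc_handles, workspace_tensor = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
        tp_rank=rank,
        tp_size=world_size,
        max_token_num=2048,      # illustrative upper bound on tokens per call
        hidden_dim=4096,         # illustrative model hidden size
        use_fp32_lamport=False,  # keep half precision in the fusion buffers
        group=dist.group.WORLD,
    )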

Note: We initialize 3 IPC buffers for trtllm_custom_all_reduce_fusion. They are sized as [buffer_size, flag_size, lamport_buffer_size * 3], where:
- buffer_size: tp_size * max_token_num * hidden_dim * sizeof(half)
- flag_size: tp_size * BarrierFlagCount * sizeof(int)
- lamport_buffer_size: tp_size * max(max_token_num, OneShotMaxToken) * tp_size * hidden_dim * sizeof(half)
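
The sizing above can be sanity-checked with a small back-of-envelope calculation, sketched below. BARRIER_FLAG_COUNT and ONE_SHOT_MAX_TOKEN stand in for TRT-LLM's internal BarrierFlagCount and OneShotMaxToken constants; the numeric values used here are placeholders for illustration, not the library's actual values, and half precision is assumed throughout.

    def fusion_workspace_sizes(tp_size: int, max_token_num: int, hidden_dim: int):
        # Placeholder stand-ins for TRT-LLM's BarrierFlagCount / OneShotMaxToken.
        BARRIER_FLAG_COUNT = 256
        ONE_SHOT_MAX_TOKEN = 128
        sizeof_half, sizeof_int = 2, 4

        buffer_size = tp_size * max_token_num * hidden_dim * sizeof_half
        flag_size = tp_size * BARRIER_FLAG_COUNT * sizeof_int
        lamport_buffer_size = (
            tp_size * max(max_token_num, ONE_SHOT_MAX_TOKEN)
            * tp_size * hidden_dim * sizeof_half
        )
        # Buffers are laid out as [buffer_size, flag_size, lamport_buffer_size * 3].
        return [buffer_size, flag_size, lamport_buffer_size * 3]

    # With the placeholder constants above, tp_size=8, max_token_num=2048 and
    # hidden_dim=4096 give roughly 128 MiB, 8 KiB, and 3 GiB respectively.
    print(fusion_workspace_sizes(8, 2048, 4096))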

The workspace is passed as the workspace field of AllReduceFusionParams.

In the all-reduce fusion kernels (allReduceFusion), tp_size and world_size are used interchangeably.

Reference: trtllm, cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu, Workspace init