flashinfer.comm.trtllm_create_ipc_workspace_for_all_reduce_fusion
flashinfer.comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(tp_rank: int, tp_size: int, max_token_num: int, hidden_dim, use_fp32_lamport: bool = False, group: ProcessGroup | None = None, create_metadata: bool = False) → Tuple[List[List[int]], Tensor] | Tuple[List[List[int]], Tensor, dict]
Parameters:
- tp_rank: the rank of the current process.
- tp_size: the size of the process group (the tensor-parallel size).
- max_token_num: the maximum number of tokens in a sequence.
- hidden_dim: the dimension of the hidden states.
- use_fp32_lamport: if True, use the fp32 data type in allreduce fusion.
- group: the process group to use.
- create_metadata: if True, return a metadata dict as the third element (default: False).
Returns:
- If create_metadata=False: (ipc_handles, workspace_tensor)
- If create_metadata=True: (ipc_handles, workspace_tensor, metadata)
where metadata is a dict containing: tp_rank, tp_size, max_token_num, hidden_dim, use_fp32_lamport, buffer_size, flag_size, lamport_comm_size, lamport_buffer_size
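Because the return type depends on create_metadata, calling code that supports both modes has to handle a 2-tuple or a 3-tuple. The following is a minimal sketch of one way to normalize the result; the helper unpack_workspace_result and the constant EXPECTED_METADATA_KEYS are hypothetical names introduced for illustration and are not part of flashinfer. The key set is copied from the list above.

```python
from typing import Any, Dict, List, Optional, Tuple

# Hypothetical constant: the metadata keys listed in the Returns section.
EXPECTED_METADATA_KEYS = frozenset({
    "tp_rank", "tp_size", "max_token_num", "hidden_dim", "use_fp32_lamport",
    "buffer_size", "flag_size", "lamport_comm_size", "lamport_buffer_size",
})


def unpack_workspace_result(
    result: Tuple[Any, ...],
) -> Tuple[List[List[int]], Any, Optional[Dict[str, Any]]]:
    """Hypothetical helper: normalize the 2- or 3-element return value
    into (ipc_handles, workspace_tensor, metadata_or_None)."""
    if len(result) == 2:
        # create_metadata=False: no metadata dict is returned.
        ipc_handles, workspace = result
        return ipc_handles, workspace, None
    if len(result) == 3:
        # create_metadata=True: check that the documented keys are present.
        ipc_handles, workspace, metadata = result
        missing = EXPECTED_METADATA_KEYS - set(metadata)
        if missing:
            raise KeyError(f"metadata is missing keys: {sorted(missing)}")
        return ipc_handles, workspace, metadata
    raise ValueError(f"expected 2 or 3 elements, got {len(result)}")
```

In real use, result would be the tuple returned by trtllm_create_ipc_workspace_for_all_reduce_fusion; the sketch only inspects the tuple's shape and never touches the tensor itself.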
Note: Three IPC buffers are initialized for trtllm_custom_all_reduce_fusion. They are sized as follows: [buffer_size, flag_size, lamport_buffer_size * 3], where:
- buffer_size: tp_size * max_token_num * hidden_dim * sizeof(half)
- flag_size: tp_size * BarrierFlagCount * sizeof(int)
- lamport_buffer_size: tp_size * max_token_num * tp_size * hidden_dim * sizeof(half)
The element size is 2 bytes for fp16/bf16, or 4 bytes for fp32 when use_fp32_lamport=True.
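The sizing formulas above can be turned into a small calculator for estimating how much IPC memory each rank allocates. This is a sketch that transcribes the documented formulas, not flashinfer's own code, and it rests on three assumptions: BarrierFlagCount is passed in as a parameter (its value is not stated here), sizeof(int) is 4 bytes, and use_fp32_lamport is assumed to change only the element size of the Lamport buffer while the first buffer stays at 2-byte elements. The names workspace_sizes and FusionWorkspaceSizes are hypothetical.

```python
from dataclasses import dataclass
from typing import List

HALF_BYTES = 2   # sizeof(half): fp16/bf16
INT_BYTES = 4    # sizeof(int), assumed 4 bytes
FP32_BYTES = 4   # element size when use_fp32_lamport=True


@dataclass(frozen=True)
class FusionWorkspaceSizes:
    """Hypothetical container for the three documented sizes (in bytes)."""
    buffer_size: int
    flag_size: int
    lamport_buffer_size: int

    @property
    def allocations(self) -> List[int]:
        # Byte sizes of the three IPC buffers, in the documented order:
        # [buffer_size, flag_size, lamport_buffer_size * 3]
        return [self.buffer_size, self.flag_size, self.lamport_buffer_size * 3]


def workspace_sizes(tp_size: int, max_token_num: int, hidden_dim: int,
                    barrier_flag_count: int,
                    use_fp32_lamport: bool = False) -> FusionWorkspaceSizes:
    """Apply the sizing formulas from the note; barrier_flag_count stands in
    for BarrierFlagCount, whose value must be supplied by the caller."""
    # Assumption: only the Lamport buffer switches to fp32 elements.
    lamport_elem = FP32_BYTES if use_fp32_lamport else HALF_BYTES
    buffer_size = tp_size * max_token_num * hidden_dim * HALF_BYTES
    flag_size = tp_size * barrier_flag_count * INT_BYTES
    lamport_buffer_size = (tp_size * max_token_num * tp_size
                           * hidden_dim * lamport_elem)
    return FusionWorkspaceSizes(buffer_size, flag_size, lamport_buffer_size)
```

For example, workspace_sizes(2, 4, 8, barrier_flag_count=3).allocations evaluates to [128, 24, 768]. Because tp_size appears twice in the Lamport formula, that buffer grows quadratically with the number of ranks, which makes it the term to watch when picking max_token_num.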
The workspace is passed as the workspace field of AllReduceFusionParams.
Here, tp_size and world_size are used interchangeably (as in allReduceFusion).
Reference: trtllm, cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu, Workspace init