flashinfer.comm.trtllm_create_ipc_workspace_for_all_reduce_fusion¶
- flashinfer.comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(tp_rank: int, tp_size: int, max_token_num: int, hidden_dim: int, use_fp32_lamport: bool = False, group: ProcessGroup | None = None, create_metadata: bool = False, comm_backend: CommBackend | None = None, use_symm_dev_mem: bool = False) → Tuple[List[List[int]], Tensor] | Tuple[List[List[int]], Tensor, dict] | Tuple[List[List[int]], Tensor, List[SymmDeviceMemory], dict]¶
Parameters:
- tp_rank: the rank of the current process.
- tp_size: the size of the process group.
- max_token_num: the maximum number of tokens in a sequence.
- hidden_dim: the dimension of the hidden states.
- use_fp32_lamport: if True, use the fp32 datatype in all-reduce fusion.
- group: the process group to use.
- create_metadata: if True, return a metadata dict as the third element (default: False).
- comm_backend: the communication backend to use.
- use_symm_dev_mem: if True, use symmetric device memory for the workspace.
Returns:
- If create_metadata=False: (ipc_handles, workspace_tensor)
- If create_metadata=True and use_symm_dev_mem=False: (ipc_handles, workspace_tensor, metadata)
- If create_metadata=True and use_symm_dev_mem=True: (ipc_handles, workspace_tensor, mem_handles, metadata)

where metadata contains: tp_rank, tp_size, max_token_num, hidden_dim, use_fp32_lamport, buffer_size, flag_size, lamport_comm_size, lamport_buffer_size, and mem_handles is a list of SymmDeviceMemory objects.
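A minimal usage sketch. The calls follow the signature above; the process-group setup and the example shapes (max_token_num=1024, hidden_dim=4096) are assumptions for illustration:

```python
import torch.distributed as dist
from flashinfer.comm import trtllm_create_ipc_workspace_for_all_reduce_fusion

# Sketch: assumes dist.init_process_group(...) has already been called on
# every rank and that each rank has selected its CUDA device.
tp_rank = dist.get_rank()
tp_size = dist.get_world_size()

# Default behavior: two-element return.
ipc_handles, workspace_tensor = trtllm_create_ipc_workspace_for_all_reduce_fusion(
    tp_rank=tp_rank,
    tp_size=tp_size,
    max_token_num=1024,   # example value
    hidden_dim=4096,      # example value
    use_fp32_lamport=False,
    group=dist.group.WORLD,
)

# With create_metadata=True, a metadata dict is appended to the return tuple.
ipc_handles, workspace_tensor, metadata = trtllm_create_ipc_workspace_for_all_reduce_fusion(
    tp_rank=tp_rank,
    tp_size=tp_size,
    max_token_num=1024,
    hidden_dim=4096,
    create_metadata=True,
)
print(metadata["buffer_size"], metadata["lamport_buffer_size"])

# workspace_tensor is what later feeds the `workspace` field of
# AllReduceFusionParams (see the note below).
```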
Note: The optional parameters make the API clunky at present. A future refactor, at the cost of backward compatibility, will make create_metadata=True and use_symm_dev_mem=True the default behavior.
Note: Three IPC buffers are initialized for trtllm_custom_all_reduce_fusion, sized as [buffer_size, flag_size, lamport_buffer_size * 3], where:
- buffer_size: tp_size * max_token_num * hidden_dim * sizeof(elem)
- flag_size: tp_size * BarrierFlagCount * sizeof(int)
- lamport_buffer_size: tp_size * max_token_num * tp_size * hidden_dim * sizeof(elem)

and sizeof(elem) = 2 (fp16/bf16) or 4 (fp32, when use_fp32_lamport=True).
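As a worked example of the sizing formulas above (a sketch; BarrierFlagCount is an internal kernel constant, so the value used here is an assumption, as are the example shapes):

```python
# Worked example of the workspace sizing formulas (illustration only).
tp_size, max_token_num, hidden_dim = 8, 1024, 4096
elem_size = 2              # fp16/bf16; use 4 when use_fp32_lamport=True
BARRIER_FLAG_COUNT = 256   # internal kernel constant -- value assumed here

buffer_size = tp_size * max_token_num * hidden_dim * elem_size
flag_size = tp_size * BARRIER_FLAG_COUNT * 4  # sizeof(int) == 4
lamport_buffer_size = tp_size * max_token_num * tp_size * hidden_dim * elem_size

# Three IPC buffers: [buffer_size, flag_size, lamport_buffer_size * 3]
total = buffer_size + flag_size + 3 * lamport_buffer_size
print(f"{buffer_size=:,} {flag_size=:,} {lamport_buffer_size=:,} {total=:,}")
```

Note that lamport_buffer_size scales with tp_size squared, so it dominates the workspace footprint at larger group sizes.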
The returned workspace tensor is passed as the workspace field of AllReduceFusionParams.
We use tp_size and world_size interchangeably here (allReduceFusion).
Reference: TensorRT-LLM, cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu (workspace init).