flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_all_reduce
- flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_all_reduce(inp: Tensor, multicast_buffer_ptr: int, buffer_ptrs_dev: int, buffer_M: int, buffer_flags_mnnvl: Tensor, nranks: int, rank: int, wait_for_results: bool, launch_with_pdl: bool, out: Tensor | None = None) → None
Perform a multi-node NVLink all-reduce operation across multiple GPUs.
This function performs an all-reduce (sum) operation using NVIDIA’s multi-node NVLink (MNNVL) technology to efficiently combine tensors across multiple GPUs and nodes.
There are three steps:
1. Scatter each GPU's input shard to the corresponding unicast buffer.
2. Perform the all-reduce on each GPU.
3. Broadcast the result to all GPUs.
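The reduction itself is an elementwise sum across ranks. The minimal sketch below (plain PyTorch on a single device, with assumed toy shapes and no MNNVL buffers) illustrates only that semantics, not the multicast mechanics:

```python
import torch

# Assumed toy setup: 4 ranks, each holding a (2, hidden_dim) shard filled with its rank id.
nranks, hidden_dim = 4, 8
shards = [torch.full((2, hidden_dim), float(r)) for r in range(nranks)]

# After an all-reduce (sum), every rank holds the elementwise sum of all ranks' shards.
expected = torch.stack(shards).sum(dim=0)
assert torch.equal(expected, torch.full((2, hidden_dim), float(sum(range(nranks)))))
```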
- Parameters:
inp – Local input shard tensor
multicast_buffer_ptr – Pointer to the multicast buffer as an integer
buffer_ptrs_dev – Pointer to device buffer pointers as an integer
buffer_M – Buffer capacity in rows, i.e. maximum number of elements // hidden_dim
buffer_flags_mnnvl – Tensor containing buffer state flags
nranks – Total number of ranks participating in the all-reduce
rank – Current process rank
wait_for_results – If True, store the reduced result in out
launch_with_pdl – If True, launch using Programmatic Dependent Launch
out (Optional) – Output tensor to store the result (required if wait_for_results is True)
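A call sketch is shown below. The MNNVL workspace setup (multicast buffer pointer, device array of buffer pointers, buffer flags tensor) is environment-specific and is represented only by labeled placeholders; only the trtllm_mnnvl_all_reduce call itself mirrors the signature documented above.

```python
import torch
from flashinfer.comm.trtllm_mnnvl_ar import trtllm_mnnvl_all_reduce

hidden_dim = 4096
rank, nranks = 0, 8                      # placeholders: current rank / world size
inp = torch.randn(32, hidden_dim, dtype=torch.bfloat16, device="cuda")  # local input shard
out = torch.empty_like(inp)              # needed because wait_for_results=True

# Placeholders: these must come from your MNNVL buffer setup
# (multicast/unicast buffer allocation); they are not created here.
multicast_buffer_ptr: int = ...          # raw pointer to the multicast buffer
buffer_ptrs_dev: int = ...               # raw pointer to the device array of buffer pointers
buffer_M: int = ...                      # buffer capacity in rows (elements // hidden_dim)
buffer_flags_mnnvl: torch.Tensor = ...   # buffer state flags tensor

trtllm_mnnvl_all_reduce(
    inp,
    multicast_buffer_ptr,
    buffer_ptrs_dev,
    buffer_M,
    buffer_flags_mnnvl,
    nranks,
    rank,
    wait_for_results=True,   # write the summed result into `out`
    launch_with_pdl=True,    # launch with Programmatic Dependent Launch
    out=out,
)
# After the call, `out` holds the elementwise sum of all ranks' `inp` shards.
```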