flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_all_reduce

flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_all_reduce(inp: Tensor, multicast_buffer_ptr: int, buffer_ptrs_dev: int, buffer_M: int, buffer_flags_mnnvl: Tensor, nranks: int, rank: int, wait_for_results: bool, launch_with_pdl: bool, out: Tensor | None = None) → None

Perform a multi-node NVLink all-reduce operation across multiple GPUs.

This function performs an all-reduce (sum) operation using NVIDIA’s multi-node NVLink (MNNVL) technology to efficiently combine tensors across multiple GPUs and nodes.

The operation proceeds in three steps:

  1. Scatter each GPU’s input shard to the right unicast buffer.

  2. Perform the all-reduce on each GPU.

  3. Broadcast the result to all GPUs.
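The sketch below illustrates these three phases on plain torch tensors within a single process. The unicast/multicast buffers, the per-rank row slicing, and the helper name are illustrative assumptions only; the real work is done inside a CUDA kernel over NVLink-mapped memory.

    import torch

    def mnnvl_all_reduce_sketch(shard, unicast_buffers, multicast_view, rank, nranks):
        # 1. Scatter: publish this rank's input shard in its unicast buffer
        #    so every peer can read it.
        unicast_buffers[rank].copy_(shard)
        # (cross-rank synchronization via the buffer flags is omitted here)

        # 2. Reduce: this rank sums its assigned slice across all peers' buffers.
        #    (The slicing scheme is an assumption for illustration.)
        rows_per_rank = shard.shape[0] // nranks
        lo, hi = rank * rows_per_rank, (rank + 1) * rows_per_rank
        partial = sum(buf[lo:hi] for buf in unicast_buffers)

        # 3. Broadcast: write the partial result through the multicast view so
        #    every rank ends up with the fully reduced tensor.
        multicast_view[lo:hi].copy_(partial)
        return multicast_view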

Parameters:
  • inp – Local input shard to be reduced

  • multicast_buffer_ptr – Pointer to the multicast buffer as an integer

  • buffer_ptrs_dev – Pointer to device buffer pointers as an integer

  • buffer_M – Buffer capacity in rows, i.e., the maximum number of elements divided by hidden_dim

  • buffer_flags_mnnvl – Tensor containing buffer state flags

  • nranks – Total number of ranks participating in the all-reduce

  • rank – Current process rank

  • wait_for_results – If True, store the reduced result in out

  • launch_with_pdl – If True, launch using Programmatic Dependent Launch

  • out (Optional) – Output tensor in which to store the result (required if wait_for_results is True)
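A minimal usage sketch follows. The workspace handles (multicast_ptr, buffer_ptrs_dev, buffer_M, buffer_flags, rank, nranks) are placeholders assumed to come from your MNNVL workspace setup, which is not covered on this page; only the trtllm_mnnvl_all_reduce call itself follows the signature above.

    import torch
    from flashinfer.comm.trtllm_mnnvl_ar import trtllm_mnnvl_all_reduce

    # Placeholders assumed to be produced by prior MNNVL workspace setup (not shown):
    #   multicast_ptr   (int)          - pointer to the multicast buffer
    #   buffer_ptrs_dev (int)          - pointer to the per-rank unicast buffer pointers
    #   buffer_M        (int)          - buffer capacity in rows (elements // hidden_dim)
    #   buffer_flags    (torch.Tensor) - buffer state flags
    #   rank, nranks    (int)          - this process's rank and the world size

    hidden_dim = 4096
    inp = torch.randn(128, hidden_dim, dtype=torch.bfloat16, device="cuda")
    out = torch.empty_like(inp)

    trtllm_mnnvl_all_reduce(
        inp,
        multicast_ptr,
        buffer_ptrs_dev,
        buffer_M,
        buffer_flags,
        nranks,
        rank,
        wait_for_results=True,   # store the reduced result in `out`
        launch_with_pdl=True,    # launch with Programmatic Dependent Launch
        out=out,
    )
    # After the call, `out` holds the elementwise sum of `inp` across all ranks.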