flashinfer.comm.vllm_all_reduce

flashinfer.comm.vllm_all_reduce(fa: int, inp: Tensor, out: Tensor, reg_buffer: int, reg_buffer_sz_bytes: int, num_ctas: int) → None

Perform an out-of-place all-reduce via the vLLM custom kernel.
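
To make "out-of-place" concrete, here is a minimal pure-Python sketch of the expected semantics. It assumes the reduction is an element-wise sum, which this page does not state. The helper name reference_all_reduce is hypothetical and not part of flashinfer. It models each rank's inp as a list and returns one out per rank, leaving the inputs unchanged.

```python
from typing import List, Sequence


def reference_all_reduce(inputs: Sequence[Sequence[float]]) -> List[List[float]]:
    """Hypothetical CPU reference: one output list per rank, each the element-wise sum.

    Assumption: the reduction operator is a sum.
    """
    if not inputs:
        raise ValueError("need at least one rank")
    length = len(inputs[0])
    if any(len(rank_inp) != length for rank_inp in inputs):
        raise ValueError("all ranks must contribute the same number of elements")
    # Reduce across ranks, position by position.
    reduced = [sum(column) for column in zip(*inputs)]
    # Out-of-place: every rank receives its own copy; the inputs are not modified.
    return [list(reduced) for _ in inputs]


rank_inputs = [[1.0, 2.0], [3.0, 4.0]]  # two ranks, two elements each
rank_outputs = reference_all_reduce(rank_inputs)
```

A reference like this may be useful for checking the contents of out in a small multi-GPU test, after copying the tensors back to the host.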

Parameters:
  • fa (int) – Handle returned by init_custom_ar().

  • inp (torch.Tensor) – Input tensor (rank-local contribution).

  • out (torch.Tensor) – Pre-allocated output tensor (receives the reduced result).

  • reg_buffer (int) – Device pointer to the registered buffer used by the kernel.

  • reg_buffer_sz_bytes (int) – Size of reg_buffer in bytes.

  • num_ctas (int) – Number of CTAs (thread blocks) to launch. Must not exceed 36; small values are usually enough to saturate NVLink bandwidth.
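
The constraints in the parameter list can be checked on the host before launching. The sketch below is one way to do that, not part of flashinfer: check_all_reduce_args and MAX_NUM_CTAS are hypothetical names. It also makes two assumptions that this page does not confirm: that out must have the same byte size as inp, and that inp must fit inside the registered buffer.

```python
MAX_NUM_CTAS = 36  # upper bound stated in the num_ctas description


def check_all_reduce_args(
    inp_nbytes: int,
    out_nbytes: int,
    reg_buffer_sz_bytes: int,
    num_ctas: int,
) -> None:
    """Hypothetical pre-flight check; raises ValueError on a bad argument."""
    if not 1 <= num_ctas <= MAX_NUM_CTAS:
        raise ValueError(f"num_ctas must be in [1, {MAX_NUM_CTAS}], got {num_ctas}")
    # Assumption: the output holds exactly as many bytes as the input.
    if out_nbytes != inp_nbytes:
        raise ValueError("out must have the same size in bytes as inp")
    # Assumption: the input is staged through the registered buffer, so it must fit.
    if inp_nbytes > reg_buffer_sz_bytes:
        raise ValueError("inp does not fit in the registered buffer")


# Example: a 1 KiB input and output, a 4 KiB registered buffer, 8 CTAs.
check_all_reduce_args(1024, 1024, 4096, 8)

# With torch tensors, the byte sizes would typically be computed as
# inp.numel() * inp.element_size(), and the call would then be:
# flashinfer.comm.vllm_all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes, num_ctas)
```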