flashinfer.comm.vllm_all_reduce
- flashinfer.comm.vllm_all_reduce(fa: int, inp: Tensor, out: Tensor, reg_buffer: int, reg_buffer_sz_bytes: int, num_ctas: int) → None
Perform an out-of-place all-reduce via the vLLM custom kernel.
- Parameters:
fa (int) – Handle returned by init_custom_ar().
inp (torch.Tensor) – Input tensor (rank-local contribution).
out (torch.Tensor) – Pre-allocated output tensor (receives the reduced result).
reg_buffer (int) – Device pointer to the registered buffer used by the kernel.
reg_buffer_sz_bytes (int) – Size of reg_buffer in bytes.
num_ctas (int) – Number of CTAs to launch. Upper bound is 36; small values are usually enough to saturate NVLink bandwidth.
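The constraints implied by the parameters above can be checked before the call. The following is a minimal sketch, not part of flashinfer: `check_all_reduce_args` and `MAX_NUM_CTAS` are hypothetical names, and the rule that `inp` must fit inside the registered buffer is an assumption about how the kernel stages its input, not something this page states.

```python
MAX_NUM_CTAS = 36  # upper bound stated in the num_ctas description


def check_all_reduce_args(inp_numel, out_numel, element_size,
                          reg_buffer_sz_bytes, num_ctas):
    """Hypothetical pre-flight check; not a flashinfer function.

    Returns the input size in bytes when the arguments look consistent,
    otherwise raises ValueError.
    """
    # Out-of-place: out is a separate tensor with as many elements as inp.
    if inp_numel != out_numel:
        raise ValueError("out must have as many elements as inp")
    if not 1 <= num_ctas <= MAX_NUM_CTAS:
        raise ValueError(f"num_ctas must be in [1, {MAX_NUM_CTAS}]")
    nbytes = inp_numel * element_size
    # Assumption: the kernel stages inp through reg_buffer, so inp must fit.
    if nbytes > reg_buffer_sz_bytes:
        raise ValueError("inp does not fit in the registered buffer")
    return nbytes
```

With PyTorch tensors, the first three arguments would typically come from `inp.numel()`, `out.numel()`, and `inp.element_size()`; if the check passes, the call itself is `vllm_all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes, num_ctas)`.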