flashinfer.comm.vllm_all_reduce

flashinfer.comm.vllm_all_reduce(fa: int, inp: Tensor, out: Tensor, reg_buffer: int, reg_buffer_sz_bytes: int, num_ctas: int) → None

Performs an out-of-place all reduce: the reduced result is written to out instead of overwriting inp.
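As a rough mental model of "out-of-place", the following pure-Python sketch simulates the semantics on a single machine. It is not FlashInfer code and does not touch GPUs; each inner list stands for one rank's input buffer. It assumes a sum reduction, which is the usual choice for tensor-parallel all reduce, although this page does not state the reduction operator.

```python
from typing import List


def reference_all_reduce(inputs: List[List[float]]) -> List[List[float]]:
    """Simulate an out-of-place sum all-reduce across len(inputs) ranks.

    inputs[r] is rank r's input buffer. A fresh output buffer is returned
    for every rank; the input buffers themselves are never modified.
    """
    if not inputs:
        return []
    n = len(inputs[0])
    if any(len(buf) != n for buf in inputs):
        raise ValueError("all ranks must contribute buffers of the same length")
    # Element-wise sum over all ranks (assumed reduction operator).
    reduced = [sum(buf[i] for buf in inputs) for i in range(n)]
    # Every rank receives its own copy of the same reduced result.
    return [list(reduced) for _ in inputs]


inp = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]  # two simulated ranks
out = reference_all_reduce(inp)
print(out)  # [[11.0, 22.0, 33.0], [11.0, 22.0, 33.0]]
print(inp)  # [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]] (unchanged)
```

After the call every rank holds the same reduced values in its own output buffer, while the inputs keep their original contents.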

Parameters:
  • fa – The handle to the custom all reduce.

  • inp – The input tensor to all reduce.

  • out – The output tensor that receives the all-reduce result.

  • reg_buffer – The registered buffer used for the all reduce.

  • reg_buffer_sz_bytes – The size of the registered buffer, in bytes.

  • num_ctas – The number of CTAs (cooperative thread arrays, i.e. CUDA thread blocks) to use for the all reduce.

    Upper bound on CTAs: generally, the bandwidth can be saturated even with a small number of SMs, so num_ctas rarely needs to be large.
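The parameters above can be sanity-checked on the caller side before launching the kernel. The sketch below is hypothetical: TensorInfo, check_all_reduce_args and max_ctas are illustrative names, not FlashInfer APIs. It assumes that inp is staged through the registered buffer and therefore must fit within reg_buffer_sz_bytes, and max_ctas is a caller-chosen cap rather than the library's actual CTA limit.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class TensorInfo:
    """Hypothetical stand-in for a tensor's shape and dtype metadata."""
    shape: Tuple[int, ...]
    element_size: int  # bytes per element, e.g. 2 for float16

    @property
    def nbytes(self) -> int:
        numel = 1
        for dim in self.shape:
            numel *= dim
        return numel * self.element_size


def check_all_reduce_args(inp: TensorInfo, out: TensorInfo,
                          reg_buffer_sz_bytes: int, num_ctas: int,
                          max_ctas: int) -> None:
    """Hypothetical pre-flight check; raises ValueError on a bad argument."""
    # Out-of-place: out must be able to hold a result shaped like inp.
    if inp.shape != out.shape or inp.element_size != out.element_size:
        raise ValueError("out must match inp in shape and dtype")
    # Assumption: inp must fit inside the registered buffer.
    if inp.nbytes > reg_buffer_sz_bytes:
        raise ValueError(
            f"input needs {inp.nbytes} bytes, but the registered buffer "
            f"holds only {reg_buffer_sz_bytes}")
    # Caller-chosen cap; a small CTA count usually suffices.
    if not 1 <= num_ctas <= max_ctas:
        raise ValueError(f"num_ctas must be in [1, {max_ctas}], got {num_ctas}")


x = TensorInfo(shape=(4, 1024), element_size=2)  # float16: 4 * 1024 * 2 bytes
check_all_reduce_args(x, TensorInfo((4, 1024), 2),
                      reg_buffer_sz_bytes=8 * 1024 * 1024, num_ctas=8, max_ctas=16)
```

A (4, 1024) float16 tensor occupies 4 × 1024 × 2 = 8192 bytes, so it fits comfortably in an 8 MiB registered buffer.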