flashinfer.comm.vllm_all_reduce
- flashinfer.comm.vllm_all_reduce(fa: int, inp: Tensor, out: Tensor, reg_buffer: int, reg_buffer_sz_bytes: int, num_ctas: int) → None
Performs an out-of-place all reduce.
- Parameters:
fa – The handle to the custom all reduce.
inp – The input tensor to all reduce.
out – The output tensor that receives the result of the all reduce.
reg_buffer – The register buffer to all reduce.
reg_buffer_sz_bytes – The size of the register buffer, in bytes.
num_ctas – The number of CTAs to use for the all reduce. An upper bound applies to the CTA count; generally, the bandwidth can be saturated with only a small number of SMs.
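To make "out-of-place" concrete, here is a minimal pure-Python sketch of the semantics, simulated on plain lists. It does not call flashinfer and says nothing about how the kernel works. The helper name `reference_all_reduce` is hypothetical, and summation as the reduction operator is an assumption, since this page does not state the operator.

```python
from typing import List


def reference_all_reduce(inputs_per_rank: List[List[float]]) -> List[List[float]]:
    """Simulate an out-of-place sum all reduce over plain Python lists.

    inputs_per_rank[r] plays the role of `inp` on rank r, and the returned
    outputs[r] plays the role of `out` on rank r.
    Hypothetical helper for illustration; not part of flashinfer.
    """
    if not inputs_per_rank:
        raise ValueError("need at least one rank")
    length = len(inputs_per_rank[0])
    if any(len(x) != length for x in inputs_per_rank):
        raise ValueError("all ranks must contribute the same number of elements")

    # Element-wise reduction across ranks (sum is an assumption here).
    reduced = [sum(vals) for vals in zip(*inputs_per_rank)]

    # Out-of-place: each rank gets a separate result buffer,
    # and the input buffers are never written to.
    return [list(reduced) for _ in inputs_per_rank]


# Two simulated ranks, three elements each.
inputs = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
outputs = reference_all_reduce(inputs)
```

In the real call, each rank passes its own `inp` and a separate `out` tensor, following the signature above: `vllm_all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes, num_ctas)`.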