flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm

flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm(prenorm_output: Tensor, normed_output: Tensor, shard_input: Tensor, multicast_buffer_ptr: int, buffer_ptrs_dev: int, unicast_ptr: int, buffer_M: int, buffer_flags_mnnvl: Tensor, nranks: int, rank: int, gamma: Tensor, epsilon: float, residual: Tensor, launch_with_pdl: bool) → None

Performs MNNVL TwoShot Allreduce + RMSNorm.

This function performs a multi-node all-reduce (sum) by first calling trtllm_mnnvl_all_reduce on shard_input. It then applies RMSNorm to the all-reduced result, reading it directly from the multicast buffer. Note: for the current rank, the multicast buffer is the same as the unicast buffer.
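The computation described above can be sketched on the host as a plain-Python reference, which may help when checking the outputs of the fused call. This is a minimal sketch, not the kernel: the name reference_allreduce_rmsnorm is hypothetical, ranks are simulated as a list of nested lists, and it assumes that residual is added to the all-reduced sum before normalization and that prenorm_output holds that pre-normalization value. Neither assumption is stated explicitly in the description.

```python
import math


def reference_allreduce_rmsnorm(shards, residual, gamma, epsilon):
    """Host-side reference for all-reduce (sum) followed by RMSNorm.

    shards:   list over ranks, each a [num_tokens][hidden_dim] nested list
    residual: [num_tokens][hidden_dim] nested list
    gamma:    [hidden_dim] list of norm weights
    Returns (prenorm, normed), both [num_tokens][hidden_dim].
    """
    hidden_dim = len(gamma)
    prenorm, normed = [], []
    for t in range(len(residual)):
        # All-reduce: element-wise sum of this token's row across all ranks.
        # ASSUMPTION: the residual is added before normalization.
        row = [
            sum(shard[t][h] for shard in shards) + residual[t][h]
            for h in range(hidden_dim)
        ]
        # RMSNorm: scale by 1 / sqrt(mean(x^2) + epsilon), then apply gamma.
        mean_sq = sum(x * x for x in row) / hidden_dim
        scale = 1.0 / math.sqrt(mean_sq + epsilon)
        prenorm.append(row)
        normed.append([x * scale * g for x, g in zip(row, gamma)])
    return prenorm, normed


# Two simulated ranks, one token, hidden_dim = 2.
shards = [[[1.0, 2.0]], [[3.0, 4.0]]]
prenorm, normed = reference_allreduce_rmsnorm(
    shards, residual=[[0.0, 0.0]], gamma=[1.0, 1.0], epsilon=1e-6
)
```

Comparing prenorm and normed against the tensors written by the fused call, within a tolerance suited to the dtype, is one way to sanity-check a multi-rank setup.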

Parameters:
  • prenorm_output – Output tensor for prenorm results

  • normed_output – Output tensor for normalized results

  • shard_input – Input tensor shard

  • multicast_buffer_ptr – Pointer address as integer for multicast buffer

  • buffer_ptrs_dev – Pointer address as integer for device buffer pointers

  • unicast_ptr – Pointer address as integer for unicast buffer

  • buffer_M – Maximum number of elements in the buffer divided by hidden_dim (integer division)

  • buffer_flags_mnnvl – Buffer flags for synchronization

  • nranks – Number of ranks in the tensor parallel group

  • rank – Current rank in the tensor parallel group

  • gamma – The gamma (norm weight) parameter for RMSNorm

  • epsilon – The epsilon parameter for RMSNorm

  • residual – The residual tensor to add

  • launch_with_pdl – Whether to launch the kernel with PDL (Programmatic Dependent Launch)
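The arithmetic behind buffer_M can be illustrated with a small helper. This is a sketch under stated assumptions: rows_capacity and fits_in_buffer are hypothetical names that are not part of flashinfer, and the check that the token count must not exceed buffer_M is an assumed precondition, not one documented here.

```python
def rows_capacity(buffer_num_elements: int, hidden_dim: int) -> int:
    """Compute buffer_M as described: maximum number of elements // hidden_dim."""
    if hidden_dim <= 0:
        raise ValueError("hidden_dim must be positive")
    return buffer_num_elements // hidden_dim


def fits_in_buffer(num_tokens: int, buffer_M: int) -> bool:
    """ASSUMPTION: a call is valid only if its token count is at most buffer_M."""
    return 0 <= num_tokens <= buffer_M


# Example: a buffer sized for 1,000,000 elements with hidden_dim = 4096.
buffer_M = rows_capacity(1_000_000, 4096)
ok = fits_in_buffer(128, buffer_M)
```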