flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm

flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm(input: Tensor, residual_in: Tensor, gamma: Tensor, workspace: MNNVLAllReduceFusionWorkspace, epsilon: float | None = None, output: Tensor | None = None, residual_out: Tensor | None = None, launch_with_pdl: bool = False, strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, weight_bias: float = 0.0) → Tuple[Tensor, Tensor]

Performs MNNVL Allreduce + Residual + RMSNorm.

This function performs a multi-node all-reduce (sum) by first calling trtllm_mnnvl_allreduce on input. It then applies residual addition and RMSNorm to the all-reduced result, reading that result directly from the multicast buffer. Note that, for the current rank, the multicast buffer is the same as the unicast buffer.
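As a rough illustration of the math only, the following pure-Python sketch reproduces the result on a single process: the per-rank inputs are summed (standing in for the all-reduce), the residual is added, and RMSNorm is applied with (gamma + weight_bias) as the scale. The function name `reference_fused_allreduce_add_rmsnorm` and the list-of-lists layout are hypothetical and not part of FlashInfer; the real kernel works on CUDA tensors across ranks through the workspace, and its numerics (dtype, accumulation order) may differ.

```python
import math
from typing import List, Tuple

Row = List[float]


def reference_fused_allreduce_add_rmsnorm(
    rank_inputs: List[List[Row]],  # one [num_tokens][hidden_dim] matrix per rank
    residual_in: List[Row],        # [num_tokens][hidden_dim]
    gamma: Row,                    # [hidden_dim]
    epsilon: float,
    weight_bias: float = 0.0,
) -> Tuple[List[Row], List[Row]]:
    """Hypothetical single-process reference: math only, no communication or GPU work."""
    hidden_dim = len(gamma)
    output: List[Row] = []
    residual_out: List[Row] = []
    for t in range(len(residual_in)):
        # 1) All-reduce (sum) this token's row across all ranks.
        reduced = [sum(rank[t][h] for rank in rank_inputs) for h in range(hidden_dim)]
        # 2) Residual addition; this row is what residual_out holds.
        added = [reduced[h] + residual_in[t][h] for h in range(hidden_dim)]
        residual_out.append(added)
        # 3) RMSNorm over hidden_dim, scaled by (gamma + weight_bias).
        inv_rms = 1.0 / math.sqrt(sum(x * x for x in added) / hidden_dim + epsilon)
        output.append([(gamma[h] + weight_bias) * added[h] * inv_rms for h in range(hidden_dim)])
    return output, residual_out


# Two simulated ranks, one token, hidden_dim = 2.
out, res = reference_fused_allreduce_add_rmsnorm(
    rank_inputs=[[[1.0, 0.5]], [[0.5, 1.0]]],
    residual_in=[[0.5, 0.5]],
    gamma=[1.0, 3.0],
    epsilon=0.0,
)
print(res)  # [[2.0, 2.0]]
print(out)  # [[1.0, 3.0]]
```

Here the summed row is [1.5, 1.5], adding the residual gives [2.0, 2.0] (its RMS is 2.0), so normalizing yields [1.0, 1.0] before scaling by gamma.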

Parameters:
  • input – Input tensor [num_tokens, hidden_dim]

  • residual_in – Residual input tensor [num_tokens, hidden_dim]

  • gamma – Gamma tensor [hidden_dim]

  • workspace – MNNVLAllReduceFusionWorkspace

  • epsilon – The epsilon parameter for RMSNorm. If not provided, an eps value from torch.finfo is used.

  • output – Output tensor for the normalized results [num_tokens, hidden_dim]. If not provided, a new empty tensor is allocated.

  • residual_out – Residual output tensor [num_tokens, hidden_dim]. If not provided, a new empty tensor is allocated.

  • launch_with_pdl – Whether to launch the kernel with Programmatic Dependent Launch (PDL).

  • strategy – MNNVLAllreduceFusionStrategy. Defaults to MNNVLAllreduceFusionStrategy.AUTO, which selects a strategy using internal heuristics.

  • weight_bias – Bias added to gamma before scaling. 0.0 (default) for standard RMSNorm (gamma * x * rsqrt(…)); 1.0 for Gemma / Qwen3.5 RMSNorm ((1 + gamma) * x * rsqrt(…)).
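A minimal sketch of how weight_bias changes the scale, assuming the two formulas given for weight_bias above: with weight_bias = 1.0 and gamma = g, the result matches weight_bias = 0.0 with gamma = 1 + g. The helper `rms_norm` is hypothetical and written in plain Python for illustration; epsilon is passed explicitly instead of relying on the default.

```python
import math
from typing import List


def rms_norm(x: List[float], gamma: List[float], epsilon: float, weight_bias: float = 0.0) -> List[float]:
    """Hypothetical single-row RMSNorm illustrating weight_bias (not the FlashInfer kernel)."""
    # inv_rms = 1 / sqrt(mean(x^2) + epsilon)
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + epsilon)
    # weight_bias = 0.0 -> gamma * x * inv_rms        (standard RMSNorm)
    # weight_bias = 1.0 -> (1 + gamma) * x * inv_rms  (Gemma-style RMSNorm)
    return [(g + weight_bias) * v * inv_rms for g, v in zip(gamma, x)]


x = [3.0, 4.0]
gamma = [0.0, 0.5]
gemma_style = rms_norm(x, gamma, 1e-6, weight_bias=1.0)
shifted_gamma = rms_norm(x, [1.0 + g for g in gamma], 1e-6, weight_bias=0.0)
standard = rms_norm(x, gamma, 1e-6)
print(gemma_style == shifted_gamma)  # True
```

Because the Gemma-style variant stores gamma as an offset from 1, a zero-initialized gamma leaves the normalized activations unscaled, whereas the standard variant with the same gamma zeros out that channel.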

Returns:

  • output – Add-residual and normalized tensor [num_tokens, hidden_dim]

  • residual_out – Add-residual tensor [num_tokens, hidden_dim]

Return type:

Tuple[Tensor, Tensor]