flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm
- flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm(input: Tensor, residual_in: Tensor, gamma: Tensor, workspace: MNNVLAllReduceFusionWorkspace, epsilon: float | None = None, output: Tensor | None = None, residual_out: Tensor | None = None, launch_with_pdl: bool = False, strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, weight_bias: float = 0.0) → Tuple[Tensor, Tensor]
Performs MNNVL Allreduce + Residual + RMSNorm.
This function performs a multi-node all-reduce (sum) by first calling trtllm_mnnvl_allreduce on input, this rank's shard. It then performs residual addition and RMSNorm on the all-reduced result, reading it directly from the multicast buffer. Note: for the current rank, the multicast buffer is the same as the unicast buffer.
- Parameters:
input – Input tensor [num_tokens, hidden_dim]
residual_in – Residual input tensor [num_tokens, hidden_dim]
gamma – Gamma tensor [hidden_dim]
workspace – The MNNVLAllReduceFusionWorkspace to use for the all-reduce.
epsilon – Epsilon for RMSNorm. If not provided, torch.finfo.eps is used.
output – Output tensor for the normalized results [num_tokens, hidden_dim]. If not provided, an empty tensor is created.
residual_out – Residual output tensor [num_tokens, hidden_dim]. If not provided, an empty tensor is created.
launch_with_pdl – Whether to launch with PDL (Programmatic Dependent Launch).
strategy – The MNNVLAllreduceFusionStrategy to use. The default, AUTO, selects a strategy with internal heuristics.
weight_bias – Bias added to gamma before scaling. 0.0 (default) for standard RMSNorm (gamma * x * rsqrt(…)); 1.0 for Gemma / Qwen3.5 RMSNorm ((1 + gamma) * x * rsqrt(…)).
- Returns:
output – Add-residual and normalized tensor [num_tokens, hidden_dim]
residual_out – Add-residual tensor [num_tokens, hidden_dim]
- Return type:
Tuple[Tensor, Tensor]
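The data flow described above can be illustrated with a dependency-free reference. This is only a sketch of the math, not the library's kernel: it assumes the standard RMSNorm definition (reciprocal square root of the mean of squares over hidden_dim plus epsilon), simulates the all-reduce by summing a Python list of per-rank inputs, and uses nested lists in place of tensors. The function name `reference_allreduce_add_rmsnorm`, the `shards` argument, and the `1e-6` epsilon default are hypothetical and do not exist in FlashInfer.

```python
import math


def reference_allreduce_add_rmsnorm(shards, residual_in, gamma,
                                    epsilon=1e-6, weight_bias=0.0):
    """Single-process reference of all-reduce (sum) + residual add + RMSNorm.

    shards:      list over ranks, each [num_tokens][hidden_dim]
                 (hypothetical stand-in for every rank's `input`).
    residual_in: [num_tokens][hidden_dim]
    gamma:       [hidden_dim]
    epsilon:     placeholder default, NOT the library's default.
    """
    hidden_dim = len(gamma)
    output, residual_out = [], []
    for t in range(len(residual_in)):
        # All-reduce (sum across ranks), then add the residual.
        row = [
            sum(shard[t][h] for shard in shards) + residual_in[t][h]
            for h in range(hidden_dim)
        ]
        # Assumed standard RMSNorm: 1 / sqrt(mean(x^2) + epsilon).
        mean_sq = sum(v * v for v in row) / hidden_dim
        inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
        # weight_bias = 0.0 scales by gamma; weight_bias = 1.0 by (1 + gamma).
        output.append(
            [(weight_bias + gamma[h]) * row[h] * inv_rms
             for h in range(hidden_dim)]
        )
        residual_out.append(row)
    return output, residual_out


# Two simulated ranks, one token, hidden_dim = 2.
shards = [[[1.0, 2.0]], [[2.0, 1.0]]]
residual = [[0.0, 1.0]]

std_out, std_res = reference_allreduce_add_rmsnorm(
    shards, residual, gamma=[1.0, 1.0], epsilon=0.0)
gemma_out, _ = reference_allreduce_add_rmsnorm(
    shards, residual, gamma=[0.0, 0.0], epsilon=0.0, weight_bias=1.0)
```

In this sketch, `std_res` holds the pre-normalization sum that corresponds to residual_out, and the two calls produce the same normalized output because a gamma of 0 with weight_bias=1.0 gives the same scale as a gamma of 1 with weight_bias=0.0.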