flashinfer.comm.allreduce_fusion

flashinfer.comm.allreduce_fusion(input: Tensor, workspace: AllReduceFusionWorkspace, pattern: int, launch_with_pdl: bool = False, trigger_completion_at_end: bool = True, output: Tensor | None = None, residual_out: Tensor | None = None, norm_out: Tensor | None = None, quant_out: Tensor | None = None, scale_out: Tensor | None = None, residual_in: Tensor | None = None, rms_gamma: Tensor | None = None, rms_eps: float = 1e-06, scale_factor: Tensor | float | None = None, layout_code: int | None = None, use_oneshot: bool | None = None, fp32_acc: bool = False, moe_reduction_device_num_experts: int | None = None, moe_reduction_scale_input: Tensor | None = None, moe_reduction_active_experts_token_input: Tensor | None = None, moe_reduction_token_input: Tensor | None = None, expanded_idx_to_permuted_idx: Tensor | None = None, expert_scale_factor: Tensor | None = None, shared_expert_output: Tensor | None = None, block_quant_group_size: int | None = None, weight_bias: float = 0.0) → Tensor

Fused AllReduce operation with optional residual add, RMSNorm, and FP8/NVFP4 quantization, depending on the selected pattern and backend.

The backend is determined automatically from the workspace type. To use a different backend, create a workspace for that backend.

Supported fusion patterns include:

  • AllReduce only

  • AllReduce + Residual + RMSNorm

  • AllReduce + Residual + RMSNorm + Quantization (FP8 / NVFP4)
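As a rough guide to what the AllReduce + Residual + RMSNorm pattern computes, here is a minimal pure-Python sketch for a single token row. It is a reference for the arithmetic only, not the fused kernel. It assumes the AllReduce is a sum across ranks and that epsilon is added inside the square root, as in the usual RMSNorm definition; the function name allreduce_residual_rmsnorm_reference is hypothetical.

```python
import math


def allreduce_residual_rmsnorm_reference(
    per_rank_inputs, residual, gamma, eps=1e-6, weight_bias=0.0
):
    """Reference arithmetic for one token row (illustrative, not the kernel).

    per_rank_inputs: one list of hidden_dim floats per rank.
    residual, gamma: lists of hidden_dim floats.
    """
    hidden_dim = len(gamma)
    # AllReduce (assumed to be a sum): add the per-rank partial results.
    reduced = [sum(rank[i] for rank in per_rank_inputs) for i in range(hidden_dim)]
    # Residual add: the pre-norm value, i.e. what residual_out would hold.
    prenorm = [r + s for r, s in zip(reduced, residual)]
    # RMSNorm: scale by the reciprocal root-mean-square of the pre-norm row.
    mean_sq = sum(v * v for v in prenorm) / hidden_dim
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # weight_bias=0.0 gives gamma * x * inv_rms; 1.0 gives (1 + gamma) * x * inv_rms.
    normed = [(weight_bias + g) * v * inv_rms for g, v in zip(gamma, prenorm)]
    return prenorm, normed
```

The two returned lists correspond to the residual_out and norm_out buffers described under Parameters.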

Note

You can reuse the same workspace with different (num_tokens, hidden_dim) combinations as long as workspace.is_buffer_size_sufficient(tp_size, num_tokens, hidden_dim, dtype) returns True.
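One way to apply this check before each call is a small guard like the sketch below. Only is_buffer_size_sufficient and its argument order come from the note above; check_workspace_reuse, WorkspaceTooSmallError, and FakeWorkspace are hypothetical names used for illustration, and the stand-in's capacity rule is an assumption, not the library's actual sizing logic.

```python
class WorkspaceTooSmallError(RuntimeError):
    """Raised when a workspace cannot hold the requested problem size."""


def check_workspace_reuse(workspace, tp_size, num_tokens, hidden_dim, dtype):
    """Return the workspace if it is large enough, otherwise raise."""
    if not workspace.is_buffer_size_sufficient(tp_size, num_tokens, hidden_dim, dtype):
        raise WorkspaceTooSmallError(
            f"workspace too small for num_tokens={num_tokens}, hidden_dim={hidden_dim}"
        )
    return workspace


class FakeWorkspace:
    """Hypothetical stand-in so the guard can be exercised without GPUs.

    It compares element counts only; a real workspace may also take
    tp_size and dtype into account.
    """

    def __init__(self, max_token_num, hidden_dim):
        self.capacity = max_token_num * hidden_dim

    def is_buffer_size_sufficient(self, tp_size, num_tokens, hidden_dim, dtype):
        return num_tokens * hidden_dim <= self.capacity
```

With a real workspace, the guard would be called with the actual torch dtype right before allreduce_fusion; when it raises, a larger workspace has to be created.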

Parameters:
  • input (torch.Tensor) – Input tensor of shape [token_num, hidden_dim].

  • workspace (AllReduceFusionWorkspace) – Workspace object created by create_allreduce_fusion_workspace(). Its concrete type (TRT-LLM vs MNNVL) determines the backend.

  • pattern (int) –

    Fusion pattern (AllReduceFusionPattern constant):

    • kAllReduce = 0

    • kARResidualRMSNorm = 1

    • kARResidualRMSNormFP8Quant = 2

    • kARResidualRMSNormFP4Quant = 3

    • kARResidualRMSNormOutFP8Quant = 4

    • kARResidualRMSNormOutFP4Quant = 5

    • kMoEReductionARResidualRMSNorm = 6 (TRT-LLM only)

    • kMoEFinalizeARResidualRMSNorm = 7 (TRT-LLM only)

    • kARResidualRMSNormPerTokenGroupFP8PackedQuant = 8 (TRT-LLM only)

    • kARResidualRMSNormOutPerTokenGroupFP8PackedQuant = 9 (TRT-LLM only)

    MNNVL supports the standard FP8/NVFP4 quant patterns (2-5). MoE and packed group quant patterns remain TRT-LLM only.

  • launch_with_pdl (bool) – Use Programmatic Dependent Launch.

  • trigger_completion_at_end (bool) – TRT-LLM only. Controls when PDL completion is signaled. True (default) signals after the kernel finishes (safe, no overlap). False signals early, allowing the next PDL-aware kernel to overlap with this one. Only safe when the next kernel also calls cudaGridDependencySynchronize(). Ignored by the MNNVL backend.

  • output (Optional[torch.Tensor]) – Pre-allocated AllReduce output buffer, shape [token_num, hidden_dim].

  • residual_out (Optional[torch.Tensor]) – Pre-allocated pre-norm output (after residual add, before norm), shape [token_num, hidden_dim].

  • norm_out (Optional[torch.Tensor]) – Pre-allocated normalized output, shape [token_num, hidden_dim].

  • quant_out (Optional[torch.Tensor]) – Pre-allocated quantized output. FP8 uses shape [token_num, hidden_dim] and NVFP4 uses shape [token_num, hidden_dim / 2].

  • scale_out (Optional[torch.Tensor]) – Pre-allocated NVFP4 scale-factor buffer. Not used by per-tensor FP8 quantization.

  • residual_in (Optional[torch.Tensor]) – Residual tensor to add, shape [token_num, hidden_dim].

  • rms_gamma (Optional[torch.Tensor]) – RMSNorm weight, shape [hidden_dim].

  • rms_eps (float) – RMSNorm epsilon for numerical stability.

  • scale_factor (Optional[Union[torch.Tensor, float]]) – Output scale used by FP8/NVFP4 quantization.

  • layout_code (Optional[int]) – NVFP4 scale-factor layout (QuantizationSFLayout). MNNVL supports SWIZZLED_128x4 and LINEAR; SWIZZLED_8x4 remains TRT-LLM only.

  • use_oneshot (Optional[bool]) – True forces the oneshot strategy and False forces the twoshot strategy; None (default) selects one using internal heuristics. When set to True for MNNVL, the workspace must have been allocated large enough for the oneshot strategy.

  • fp32_acc (bool) – TRT-LLM only. Use FP32 accumulation for AllReduce.

  • moe_reduction_device_num_experts (Optional[int]) – Number of local experts on this device, required for pattern=kMoEReductionARResidualRMSNorm.

  • moe_reduction_scale_input (Optional[torch.Tensor]) – Per-token-per-expert scales, shape [token_num, num_experts].

  • moe_reduction_active_experts_token_input (Optional[torch.Tensor]) – Per-token-per-expert outputs, shape [token_num * num_experts, hidden_dim].

  • moe_reduction_token_input (Optional[torch.Tensor]) – Per-token input (e.g. FC2 output), shape [token_num, hidden_dim].

  • expanded_idx_to_permuted_idx (Optional[torch.Tensor]) – Mapping from (token, topk_idx) to permuted expert output row. Shape [token_num, top_k], dtype int32. Required for pattern=kMoEFinalizeARResidualRMSNorm.

  • expert_scale_factor (Optional[torch.Tensor]) – Router weights for each selected expert, shape [token_num, top_k].

  • shared_expert_output (Optional[torch.Tensor]) – Optional shared-expert output to add, shape [token_num, hidden_dim].

  • block_quant_group_size (Optional[int]) – Group size for per-token-group FP8 packed quantization patterns (TRT-LLM only).

  • weight_bias (float) –

    Bias added to rms_gamma before scaling.

    • 0.0 (default): standard RMSNorm (out = gamma * x * rsqrt(...)).

    • 1.0: Gemma / Qwen3.5 RMSNorm (out = (1 + gamma) * x * rsqrt(...)).

    Supported by both TRT-LLM and MNNVL backends for standard RMSNorm and quant patterns (1-5), and by TRT-LLM for MoE RMSNorm variants. Ignored for kAllReduce.
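The pattern values, their backend restrictions, and the quant_out shapes listed above can be collected into a small pre-flight check. The sketch below is a local mirror for illustration: Pattern copies the documented integer values but is not the library's AllReduceFusionPattern, the backend labels "trtllm" and "mnnvl" are assumed strings, and quant_out_shape covers only patterns 2-5, because the text above does not give a shape for the packed group patterns.

```python
from enum import IntEnum


class Pattern(IntEnum):
    """Local mirror of the documented pattern values (illustrative only)."""

    kAllReduce = 0
    kARResidualRMSNorm = 1
    kARResidualRMSNormFP8Quant = 2
    kARResidualRMSNormFP4Quant = 3
    kARResidualRMSNormOutFP8Quant = 4
    kARResidualRMSNormOutFP4Quant = 5
    kMoEReductionARResidualRMSNorm = 6
    kMoEFinalizeARResidualRMSNorm = 7
    kARResidualRMSNormPerTokenGroupFP8PackedQuant = 8
    kARResidualRMSNormOutPerTokenGroupFP8PackedQuant = 9


# Patterns documented as TRT-LLM only.
TRTLLM_ONLY = frozenset(Pattern(v) for v in (6, 7, 8, 9))
FP8_PATTERNS = frozenset({Pattern(2), Pattern(4)})
FP4_PATTERNS = frozenset({Pattern(3), Pattern(5)})


def is_supported(pattern, backend):
    """Check a pattern against the documented backend restrictions."""
    pattern = Pattern(pattern)
    if backend == "trtllm":
        return True
    if backend == "mnnvl":
        return pattern not in TRTLLM_ONLY
    raise ValueError(f"unknown backend label: {backend!r}")


def quant_out_shape(pattern, token_num, hidden_dim):
    """Expected quant_out shape for patterns 2-5, else None."""
    pattern = Pattern(pattern)
    if pattern in FP8_PATTERNS:
        return (token_num, hidden_dim)
    if pattern in FP4_PATTERNS:
        # NVFP4 uses half as many columns as the hidden dimension.
        return (token_num, hidden_dim // 2)
    return None
```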

Returns:

Output tensor for the selected pattern. Quant patterns return quant_out, RMSNorm patterns return norm_out, and kAllReduce returns output.
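The rule above can be written as a small lookup. The sketch below classifies patterns by name, which is one reading of "quant patterns" and "RMSNorm patterns" (any name containing Quant is treated as a quant pattern, and the MoE variants as RMSNorm patterns); returned_buffer_name is a hypothetical helper, not part of the library.

```python
def returned_buffer_name(pattern_name: str) -> str:
    """Name of the buffer the call is documented to return for a pattern."""
    # Quant patterns are checked first because their names also contain "RMSNorm".
    if "Quant" in pattern_name:
        return "quant_out"
    if "RMSNorm" in pattern_name:
        return "norm_out"
    if pattern_name == "kAllReduce":
        return "output"
    raise ValueError(f"unknown pattern name: {pattern_name!r}")
```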

Return type:

torch.Tensor

Examples

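>>> # Basic AllReduce + Residual + RMSNorm
>>> import torch
>>> from flashinfer.comm import (
...     AllReduceFusionPattern,
...     allreduce_fusion,
...     create_allreduce_fusion_workspace,
... )
>>> # Each process passes its own rank; rank=0 is shown here.
>>> workspace = create_allreduce_fusion_workspace(
...     backend="auto", world_size=8, rank=0,
...     max_token_num=2048, hidden_dim=4096, dtype=torch.bfloat16,
... )
>>> # hidden_states and residual ([token_num, 4096]) and norm_weight ([4096])
>>> # are assumed to exist already.
>>> prenorm = torch.empty_like(hidden_states)
>>> normed = torch.empty_like(hidden_states)
>>> # For this pattern the return value is norm_out.
>>> output = allreduce_fusion(
...     input=hidden_states,
...     workspace=workspace,
...     pattern=AllReduceFusionPattern.kARResidualRMSNorm,
...     launch_with_pdl=True,
...     residual_out=prenorm,
...     norm_out=normed,
...     residual_in=residual,
...     rms_gamma=norm_weight,
... )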