flashinfer.comm.allreduce_fusion

flashinfer.comm.allreduce_fusion(input: Tensor, workspace: AllReduceFusionWorkspace, pattern: int, launch_with_pdl: bool = False, trigger_completion_at_end: bool = True, output: Tensor | None = None, residual_out: Tensor | None = None, norm_out: Tensor | None = None, quant_out: Tensor | None = None, scale_out: Tensor | None = None, residual_in: Tensor | None = None, rms_gamma: Tensor | None = None, rms_eps: float = 1e-06, scale_factor: Tensor | float | None = None, layout_code: int | None = None, use_oneshot: bool | None = None, fp32_acc: bool = False, moe_reduction_device_num_experts: int | None = None, moe_reduction_scale_input: Tensor | None = None, moe_reduction_active_experts_token_input: Tensor | None = None, moe_reduction_token_input: Tensor | None = None, expanded_idx_to_permuted_idx: Tensor | None = None, expert_scale_factor: Tensor | None = None, shared_expert_output: Tensor | None = None) → Tensor

AllReduce + RMSNorm fusion operation.

The backend is automatically determined from the workspace type. If you need a different backend, create a workspace for that backend.

Supports multiple fusion patterns:

  • AllReduce only

  • AllReduce + Residual + RMSNorm

  • AllReduce + Residual + RMSNorm + Quantization (FP8/FP4)
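For intuition, the unfused math behind the Residual + RMSNorm part of these patterns can be sketched as follows. This is a single-rank reference where the all-reduce sum is assumed to be already formed; the function and variable names are illustrative, not part of the flashinfer API:

```python
import numpy as np

def reference_residual_rmsnorm(reduced, residual_in, gamma, eps=1e-6):
    """Unfused reference for the post-all-reduce part of the fusion.

    `reduced` stands in for the all-reduce result across TP ranks.
    Returns the prenorm output (residual_out) and the normalized
    output (norm_out), mirroring the fused kernel's two outputs.
    """
    residual_out = reduced + residual_in
    # RMSNorm: scale by the reciprocal root-mean-square over hidden_dim.
    rms = np.sqrt(np.mean(residual_out ** 2, axis=-1, keepdims=True) + eps)
    norm_out = (residual_out / rms) * gamma
    return residual_out, norm_out

reduced = np.ones((2, 4), dtype=np.float32)   # pretend all-reduce output
residual = np.zeros((2, 4), dtype=np.float32)
gamma = np.ones(4, dtype=np.float32)
prenorm, normed = reference_residual_rmsnorm(reduced, residual, gamma)
```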

Note on Workspace Reusability: You can reuse the same workspace with different (token_num, hidden_dim) combinations as long as workspace.is_buffer_size_sufficient(token_num, hidden_dim, tp_size, dtype) returns True.
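A reuse check can be sketched as below. `_FakeWorkspace` is a self-contained stand-in that illustrates only the contract of is_buffer_size_sufficient (a capacity check against the buffers the workspace was created with); its sizing rule and signature are hypothetical, not flashinfer's actual implementation:

```python
from dataclasses import dataclass

@dataclass
class _FakeWorkspace:
    """Hypothetical stand-in for AllReduceFusionWorkspace's capacity check."""
    max_bytes: int

    def is_buffer_size_sufficient(self, token_num, hidden_dim, tp_size, dtype_size):
        # Illustrative sizing rule only: one dtype-sized slot per element, per rank.
        return token_num * hidden_dim * dtype_size * tp_size <= self.max_bytes

# Workspace sized for 2048 x 4096 bf16 (2 bytes) at tp_size=8.
ws = _FakeWorkspace(max_bytes=2048 * 4096 * 2 * 8)
ok_small = ws.is_buffer_size_sufficient(1024, 4096, 8, 2)   # smaller shape: reusable
ok_large = ws.is_buffer_size_sufficient(4096, 4096, 8, 2)   # larger shape: needs a new workspace
```

If the check fails, create a fresh workspace for the larger shape rather than calling the fusion op with insufficient buffers.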

Parameters:
  • input – Input tensor [token_num, hidden_dim]

  • workspace – Workspace object (type determines backend, see create_allreduce_fusion_workspace)

  • pattern – Fusion pattern (an AllReduceFusionPattern constant, 0-7):
    - kAllReduce = 0
    - kARResidualRMSNorm = 1
    - kARResidualRMSNormFP8Quant = 2
    - kARResidualRMSNormFP4Quant = 3
    - kARResidualRMSNormOutFP8Quant = 4
    - kARResidualRMSNormOutFP4Quant = 5
    - kMoEReductionARResidualRMSNorm = 6 (trtllm only)
    - kMoEFinalizeARResidualRMSNorm = 7 (trtllm only)
    Note: the MNNVL backend only supports patterns 0 and 1; the MoE patterns (6-7) require the trtllm backend.

  • launch_with_pdl – Use Programmatic Dependent Launch

  • trigger_completion_at_end – [trtllm only] Controls when PDL completion is signaled. True (default): signal completion after the kernel finishes (safe, no overlap). False: signal completion early, allowing the next PDL-aware kernel to overlap with this one. Only safe when the subsequent kernel also uses cudaGridDependencySynchronize(). Ignored by MNNVL backend.

  Output tensors:

  • output – AllReduce output [token_num, hidden_dim]

  • residual_out – Prenorm output (after residual add, before norm) [token_num, hidden_dim]

  • norm_out – Normalized output [token_num, hidden_dim]

  • quant_out – Quantized output [token_num, hidden_dim] [trtllm only]

  • scale_out – Quantization scale factors [trtllm only]

  Control parameters:

  • residual_in – Residual tensor to ADD [token_num, hidden_dim]

  • rms_gamma – RMSNorm weight [hidden_dim]

  • rms_eps – RMSNorm epsilon for numerical stability

  • scale_factor – Input scale factor for quantization [trtllm only]

  • layout_code – Scale factor layout (QuantizationSFLayout) [trtllm only]


  • use_oneshot – Use the one-shot strategy instead of two-shot. If None, internal heuristics choose. Note: when explicitly set to True, the MNNVL backend must be initialized with a sufficiently large workspace.

  • fp32_acc – [trtllm only] Use FP32 accumulation for AllReduce

  MoE Reduction parameters (pattern 6, trtllm only):

  • moe_reduction_device_num_experts – Number of local experts on this device

  • moe_reduction_scale_input – Per-token-per-expert scale [token_num, num_experts]

  • moe_reduction_active_experts_token_input – Per-token-per-expert outputs [token_num * num_experts, hidden_dim]

  • moe_reduction_token_input – Per-token input (e.g. FC2 output) [token_num, hidden_dim]

  MoE Finalize parameters (pattern 7, trtllm only):

  • expanded_idx_to_permuted_idx – Mapping from (token, topk_idx) to permuted expert output row. Shape [token_num, top_k], dtype int32.

  • expert_scale_factor – Router weights for each selected expert [token_num, top_k]

  • shared_expert_output – Optional shared expert output to add [token_num, hidden_dim]

Returns:

Output tensor (typically norm_out for fusion cases, output otherwise)

Examples

>>> # Basic AllReduce + Residual + RMSNorm
>>> workspace = create_allreduce_fusion_workspace(
...     backend="auto",
...     world_size=8,
...     rank=0,
...     max_token_num=2048,
...     hidden_dim=4096,
...     dtype=torch.bfloat16,
... )
>>>
>>> # Pre-allocate output tensors
>>> prenorm = torch.empty_like(hidden_states)
>>> normed = torch.empty_like(hidden_states)
>>>
>>> # Call fusion - backend inferred from workspace type
>>> output = allreduce_fusion(
...     input=hidden_states,
...     workspace=workspace,
...     pattern=AllReduceFusionPattern.kARResidualRMSNorm,
...     launch_with_pdl=True,
...     residual_out=prenorm,
...     norm_out=normed,
...     residual_in=residual,
...     rms_gamma=norm_weight
... )
>>> # output == normed (final result)
>>>
>>> # With FP8 quantization
>>> quant = torch.empty_like(hidden_states, dtype=torch.float8_e4m3fn)
>>> scales = torch.empty(token_num * hidden_dim // 16, dtype=torch.float16, device="cuda")
>>>
>>> output = allreduce_fusion(
...     input=hidden_states,
...     workspace=workspace,
...     pattern=AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
...     norm_out=normed,
...     quant_out=quant,
...     scale_out=scales,
...     residual_in=residual,
...     rms_gamma=norm_weight,
...     scale_factor=scale_tensor
... )
>>>
>>> # MoE Finalize + AllReduce + Residual + RMSNorm (e.g. DeepSeek)
>>> # input = permuted expert outputs [max_permuted_count, hidden_dim]
>>> # expanded_idx_to_permuted_idx = [token_num, top_k] mapping
>>> normed = torch.empty(token_num, hidden_dim, dtype=torch.bfloat16, device="cuda")
>>> residual_updated = torch.empty_like(residual)
>>> output = allreduce_fusion(
...     input=permuted_expert_output,
...     workspace=workspace,
...     pattern=AllReduceFusionPattern.kMoEFinalizeARResidualRMSNorm,
...     launch_with_pdl=True,
...     residual_in=residual,
...     residual_out=residual_updated,
...     norm_out=normed,
...     rms_gamma=norm_weight,
...     rms_eps=1e-6,
...     expanded_idx_to_permuted_idx=idx_mapping,
...     expert_scale_factor=router_weights,
...     shared_expert_output=shared_expert_out,
... )
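As a sanity reference for the MoE-finalize semantics in the example above, the following unfused numpy sketch assumes the standard finalize formula (a router-weighted gather of permuted expert rows plus the optional shared-expert output, then residual add and RMSNorm) on a single rank, where the all-reduce is a no-op. Names are illustrative, not part of the API:

```python
import numpy as np

def reference_moe_finalize(permuted, idx_map, router_w, shared, residual, gamma, eps=1e-6):
    """permuted: [permuted_count, hidden]; idx_map: [tokens, top_k] int32;
    router_w: [tokens, top_k]; shared/residual: [tokens, hidden]; gamma: [hidden]."""
    # Gather each token's top_k expert rows and combine with router weights.
    gathered = permuted[idx_map]                          # [tokens, top_k, hidden]
    combined = (router_w[..., None] * gathered).sum(axis=1)
    combined += shared                                    # optional shared-expert output
    residual_out = combined + residual                    # prenorm output
    rms = np.sqrt(np.mean(residual_out ** 2, axis=-1, keepdims=True) + eps)
    return residual_out, (residual_out / rms) * gamma

tokens, top_k, hidden = 2, 2, 4
permuted = np.arange(tokens * top_k * hidden, dtype=np.float32).reshape(tokens * top_k, hidden)
idx_map = np.array([[0, 1], [2, 3]], dtype=np.int32)
router_w = np.full((tokens, top_k), 0.5, dtype=np.float32)
shared = np.zeros((tokens, hidden), dtype=np.float32)
residual = np.zeros((tokens, hidden), dtype=np.float32)
gamma = np.ones(hidden, dtype=np.float32)
prenorm, normed = reference_moe_finalize(permuted, idx_map, router_w, shared, residual, gamma)
```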