flashinfer.comm.allreduce_fusion¶
- flashinfer.comm.allreduce_fusion(input: Tensor, workspace: AllReduceFusionWorkspace, pattern: int, launch_with_pdl: bool = False, trigger_completion_at_end: bool = True, output: Tensor | None = None, residual_out: Tensor | None = None, norm_out: Tensor | None = None, quant_out: Tensor | None = None, scale_out: Tensor | None = None, residual_in: Tensor | None = None, rms_gamma: Tensor | None = None, rms_eps: float = 1e-06, scale_factor: Tensor | float | None = None, layout_code: int | None = None, use_oneshot: bool | None = None, fp32_acc: bool = False, moe_reduction_device_num_experts: int | None = None, moe_reduction_scale_input: Tensor | None = None, moe_reduction_active_experts_token_input: Tensor | None = None, moe_reduction_token_input: Tensor | None = None, expanded_idx_to_permuted_idx: Tensor | None = None, expert_scale_factor: Tensor | None = None, shared_expert_output: Tensor | None = None) Tensor¶
AllReduce + RMSNorm fusion operation.
The backend is determined automatically from the workspace type. To use a different backend, create the workspace for that backend instead.
Supports multiple fusion patterns:
- AllReduce only
- AllReduce + Residual + RMSNorm
- AllReduce + Residual + RMSNorm + Quantization (FP8/FP4)
Note on workspace reusability: the same workspace can be reused across different (token_num, hidden_dim) combinations as long as workspace.is_buffer_size_sufficient(token_num, hidden_dim, tp_size, dtype) returns True.
- Parameters:
input – Input tensor [token_num, hidden_dim]
workspace – Workspace object (type determines backend, see create_allreduce_fusion_workspace)
pattern – Fusion pattern (AllReduceFusionPattern constant, 0-7):
- kAllReduce = 0
- kARResidualRMSNorm = 1
- kARResidualRMSNormFP8Quant = 2
- kARResidualRMSNormFP4Quant = 3
- kARResidualRMSNormOutFP8Quant = 4
- kARResidualRMSNormOutFP4Quant = 5
- kMoEReductionARResidualRMSNorm = 6 (trtllm only)
- kMoEFinalizeARResidualRMSNorm = 7 (trtllm only)
Note: MNNVL only supports patterns 0 and 1. MOE patterns (6-7) only support the trtllm backend.
launch_with_pdl – Use Programmatic Dependent Launch
trigger_completion_at_end – [trtllm only] Controls when PDL completion is signaled. True (default): signal completion after the kernel finishes (safe, no overlap). False: signal completion early, allowing the next PDL-aware kernel to overlap with this one. Only safe when the subsequent kernel also uses cudaGridDependencySynchronize(). Ignored by MNNVL backend.
# ===== Output tensors =====
output – AllReduce output [token_num, hidden_dim]
residual_out – Prenorm output (after residual add, before norm) [token_num, hidden_dim]
norm_out – Normalized output [token_num, hidden_dim]
quant_out – Quantized output [token_num, hidden_dim] [trtllm only]
scale_out – Quantization scale factors [trtllm only]
# ===== Control parameters =====
residual_in – Residual tensor to ADD [token_num, hidden_dim]
rms_gamma – RMSNorm weight [hidden_dim]
rms_eps – RMSNorm epsilon for numerical stability
scale_factor – Input scale factor for quantization [trtllm only]
layout_code – Scale factor layout (QuantizationSFLayout) [trtllm only]
use_oneshot – Use the oneshot strategy instead of twoshot. If None, internal heuristics decide. Note: when explicitly set to True, the MNNVL backend must be initialized with a sufficiently large workspace.
fp32_acc – [trtllm only] Use FP32 accumulation for AllReduce
# ===== MOE Finalize parameters =====
moe_reduction_device_num_experts – Number of local experts on this device
moe_reduction_scale_input – Per-token-per-expert scale [token_num, num_experts]
moe_reduction_active_experts_token_input – Per-token-per-expert outputs [token_num * num_experts, hidden_dim]
moe_reduction_token_input – Per-token input (e.g. FC2 output) [token_num, hidden_dim]
expanded_idx_to_permuted_idx – Mapping from (token, topk_idx) to permuted expert output row. Shape [token_num, top_k], dtype int32.
expert_scale_factor – Router weights for each selected expert [token_num, top_k]
shared_expert_output – Optional shared expert output to add [token_num, hidden_dim]
- Returns:
Output tensor (typically norm_out for fusion cases, output otherwise)
Examples
>>> # Basic AllReduce + Residual + RMSNorm
>>> workspace = create_allreduce_fusion_workspace(
...     backend="auto",
...     world_size=8,
...     rank=0,
...     max_token_num=2048,
...     hidden_dim=4096,
...     dtype=torch.bfloat16,
... )
>>>
>>> # Pre-allocate output tensors
>>> prenorm = torch.empty_like(hidden_states)
>>> normed = torch.empty_like(hidden_states)
>>>
>>> # Call fusion - backend inferred from workspace type
>>> output = allreduce_fusion(
...     input=hidden_states,
...     workspace=workspace,
...     pattern=AllReduceFusionPattern.kARResidualRMSNorm,
...     launch_with_pdl=True,
...     residual_out=prenorm,
...     norm_out=normed,
...     residual_in=residual,
...     rms_gamma=norm_weight,
... )
>>> # output == normed (final result)
>>> # With FP8 quantization
>>> quant = torch.empty_like(hidden_states, dtype=torch.float8_e4m3fn)
>>> scales = torch.empty(token_num * hidden_dim // 16, dtype=torch.float16)
>>>
>>> output = allreduce_fusion(
...     input=hidden_states,
...     workspace=workspace,
...     pattern=AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
...     norm_out=normed,
...     quant_out=quant,
...     scale_out=scales,
...     residual_in=residual,
...     rms_gamma=norm_weight,
...     scale_factor=scale_tensor,
... )
>>> # MoE Finalize + AllReduce + Residual + RMSNorm (e.g. DeepSeek)
>>> # input = permuted expert outputs [max_permuted_count, hidden_dim]
>>> # expanded_idx_to_permuted_idx = [token_num, top_k] mapping
>>> normed = torch.empty(token_num, hidden_dim, dtype=torch.bfloat16, device="cuda")
>>> residual_updated = torch.empty_like(residual)
>>> output = allreduce_fusion(
...     input=permuted_expert_output,
...     workspace=workspace,
...     pattern=AllReduceFusionPattern.kMoEFinalizeARResidualRMSNorm,
...     launch_with_pdl=True,
...     residual_in=residual,
...     residual_out=residual_updated,
...     norm_out=normed,
...     rms_gamma=norm_weight,
...     rms_eps=1e-6,
...     expanded_idx_to_permuted_idx=idx_mapping,
...     expert_scale_factor=router_weights,
...     shared_expert_output=shared_expert_out,
... )
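The pattern/backend constraints noted in the parameter list (MNNVL supports only patterns 0-1; MOE patterns 6-7 are trtllm-only) can be sketched as a small pre-flight check. `check_pattern_supported` is a hypothetical helper written for illustration; only the constant names and values come from the documented AllReduceFusionPattern list:

```python
from enum import IntEnum


# Mirrors the documented AllReduceFusionPattern constants (0-7).
class AllReduceFusionPattern(IntEnum):
    kAllReduce = 0
    kARResidualRMSNorm = 1
    kARResidualRMSNormFP8Quant = 2
    kARResidualRMSNormFP4Quant = 3
    kARResidualRMSNormOutFP8Quant = 4
    kARResidualRMSNormOutFP4Quant = 5
    kMoEReductionARResidualRMSNorm = 6
    kMoEFinalizeARResidualRMSNorm = 7


def check_pattern_supported(pattern: int, backend: str) -> bool:
    """Encode the documented limits: MNNVL handles only patterns 0-1;
    trtllm handles all patterns, including the MOE ones (6-7)."""
    p = AllReduceFusionPattern(pattern)
    if backend == "mnnvl":
        return p in (AllReduceFusionPattern.kAllReduce,
                     AllReduceFusionPattern.kARResidualRMSNorm)
    if backend == "trtllm":
        return True
    raise ValueError(f"unknown backend: {backend}")


print(check_pattern_supported(1, "mnnvl"))  # True
print(check_pattern_supported(6, "mnnvl"))  # False: MOE patterns are trtllm-only
```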