flashinfer.comm.allreduce_fusion
- flashinfer.comm.allreduce_fusion(input: Tensor, workspace: AllReduceFusionWorkspace, pattern: int, launch_with_pdl: bool = False, trigger_completion_at_end: bool = True, output: Tensor | None = None, residual_out: Tensor | None = None, norm_out: Tensor | None = None, quant_out: Tensor | None = None, scale_out: Tensor | None = None, residual_in: Tensor | None = None, rms_gamma: Tensor | None = None, rms_eps: float = 1e-06, scale_factor: Tensor | float | None = None, layout_code: int | None = None, use_oneshot: bool | None = None, fp32_acc: bool = False, moe_reduction_device_num_experts: int | None = None, moe_reduction_scale_input: Tensor | None = None, moe_reduction_active_experts_token_input: Tensor | None = None, moe_reduction_token_input: Tensor | None = None, expanded_idx_to_permuted_idx: Tensor | None = None, expert_scale_factor: Tensor | None = None, shared_expert_output: Tensor | None = None, block_quant_group_size: int | None = None, weight_bias: float = 0.0) -> Tensor
AllReduce + RMSNorm fusion operation, with optional FP8/NVFP4 quantization for supported backends.
Backend is automatically determined from workspace type. If you need a different backend, create the workspace for that backend.
Supports multiple fusion patterns:
- AllReduce only
- AllReduce + Residual + RMSNorm
- AllReduce + Residual + RMSNorm + Quantization (FP8 / NVFP4)
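To make the data flow of the AllReduce + Residual + RMSNorm pattern concrete, here is a single-process, plain-Python sketch of what it computes for one token. It is not FlashInfer code: the function name reference_ar_residual_rmsnorm is hypothetical, the AllReduce is emulated by summing one vector per simulated rank, RMSNorm is assumed to take the usual rsqrt(mean(x^2) + eps) form, and dtype handling and accumulation precision (e.g. fp32_acc) are not modeled. The weight_bias argument mirrors the parameter of the same name.

```python
import math
from typing import List, Sequence, Tuple


def reference_ar_residual_rmsnorm(
    rank_inputs: Sequence[Sequence[float]],
    residual_in: Sequence[float],
    rms_gamma: Sequence[float],
    rms_eps: float = 1e-6,
    weight_bias: float = 0.0,
) -> Tuple[List[float], List[float], List[float]]:
    """Hypothetical single-token emulation of AllReduce + Residual + RMSNorm.

    rank_inputs[r] is the hidden vector held by simulated rank r before the
    AllReduce. Returns (output, residual_out, norm_out) for that token.
    """
    hidden_dim = len(residual_in)
    # 1. AllReduce: elementwise sum of every rank's vector.
    output = [sum(vec[i] for vec in rank_inputs) for i in range(hidden_dim)]
    # 2. Residual add: the pre-norm value that would land in residual_out.
    residual_out = [o + r for o, r in zip(output, residual_in)]
    # 3. RMSNorm over the hidden dimension (assumed rsqrt(mean(x^2) + eps) form).
    mean_sq = sum(x * x for x in residual_out) / hidden_dim
    inv_rms = 1.0 / math.sqrt(mean_sq + rms_eps)
    # weight_bias = 0.0 gives gamma * x_hat; weight_bias = 1.0 gives (1 + gamma) * x_hat.
    norm_out = [(weight_bias + g) * x * inv_rms for g, x in zip(rms_gamma, residual_out)]
    return output, residual_out, norm_out


# Two simulated ranks, hidden_dim = 2.
out, pre, normed = reference_ar_residual_rmsnorm(
    rank_inputs=[[1.0, 2.0], [3.0, 4.0]], residual_in=[0.5, -1.0], rms_gamma=[1.0, 2.0]
)
```

The three returned lists correspond to output, residual_out and norm_out; comparing a kernel's buffers against a reference like this, with a tolerance suited to bf16, is one way to sanity-check a configuration.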
Note
You can reuse the same workspace with different (num_tokens, hidden_dim) combinations as long as workspace.is_buffer_size_sufficient(tp_size, num_tokens, hidden_dim, dtype) returns True.
- Parameters:
input (torch.Tensor) – Input tensor of shape [token_num, hidden_dim].
workspace (AllReduceFusionWorkspace) – Workspace object created by create_allreduce_fusion_workspace(). Its concrete type (TRT-LLM vs MNNVL) determines the backend.
pattern (int) – Fusion pattern (AllReduceFusionPattern constant):
  kAllReduce = 0
  kARResidualRMSNorm = 1
  kARResidualRMSNormFP8Quant = 2
  kARResidualRMSNormFP4Quant = 3
  kARResidualRMSNormOutFP8Quant = 4
  kARResidualRMSNormOutFP4Quant = 5
  kMoEReductionARResidualRMSNorm = 6 (TRT-LLM only)
  kMoEFinalizeARResidualRMSNorm = 7 (TRT-LLM only)
  kARResidualRMSNormPerTokenGroupFP8PackedQuant = 8 (TRT-LLM only)
  kARResidualRMSNormOutPerTokenGroupFP8PackedQuant = 9 (TRT-LLM only)
MNNVL supports the standard FP8/NVFP4 quant patterns (2-5). MoE and packed group-quant patterns remain TRT-LLM only.
launch_with_pdl (bool) – Launch the kernel with Programmatic Dependent Launch (PDL).
trigger_completion_at_end (bool) – TRT-LLM only. Controls when PDL completion is signaled. True (default) signals after the kernel finishes (safe, no overlap). False signals early, allowing the next PDL-aware kernel to overlap with this one; this is only safe when the next kernel also calls cudaGridDependencySynchronize(). Ignored by the MNNVL backend.
output (Optional[torch.Tensor]) – Pre-allocated AllReduce output buffer, shape [token_num, hidden_dim].
residual_out (Optional[torch.Tensor]) – Pre-allocated pre-norm output (after the residual add, before the norm), shape [token_num, hidden_dim].
norm_out (Optional[torch.Tensor]) – Pre-allocated normalized output, shape [token_num, hidden_dim].
quant_out (Optional[torch.Tensor]) – Pre-allocated quantized output. FP8 uses shape [token_num, hidden_dim]; NVFP4 packs two 4-bit values per byte and uses shape [token_num, hidden_dim / 2].
scale_out (Optional[torch.Tensor]) – Pre-allocated NVFP4 scale-factor buffer. Not used by per-tensor FP8 quantization.
residual_in (Optional[torch.Tensor]) – Residual tensor to add, shape [token_num, hidden_dim].
rms_gamma (Optional[torch.Tensor]) – RMSNorm weight, shape [hidden_dim].
rms_eps (float) – RMSNorm epsilon for numerical stability.
scale_factor (Optional[Union[torch.Tensor, float]]) – Output scale used by FP8/NVFP4 quantization.
layout_code (Optional[int]) – NVFP4 scale-factor layout (QuantizationSFLayout). MNNVL supports SWIZZLED_128x4 and LINEAR; SWIZZLED_8x4 remains TRT-LLM only.
use_oneshot (Optional[bool]) – True/False forces the oneshot/twoshot strategy; None (default) uses internal heuristics. When set to True for MNNVL, the workspace must have been allocated with a sufficiently large size.
fp32_acc (bool) – TRT-LLM only. Use FP32 accumulation for AllReduce.
moe_reduction_device_num_experts (Optional[int]) – Number of local experts on this device; required for pattern=kMoEReductionARResidualRMSNorm.
moe_reduction_scale_input (Optional[torch.Tensor]) – Per-token-per-expert scales, shape [token_num, num_experts].
moe_reduction_active_experts_token_input (Optional[torch.Tensor]) – Per-token-per-expert outputs, shape [token_num * num_experts, hidden_dim].
moe_reduction_token_input (Optional[torch.Tensor]) – Per-token input (e.g. FC2 output), shape [token_num, hidden_dim].
expanded_idx_to_permuted_idx (Optional[torch.Tensor]) – Mapping from (token, topk_idx) to the permuted expert-output row. Shape [token_num, top_k], dtype int32. Required for pattern=kMoEFinalizeARResidualRMSNorm.
expert_scale_factor (Optional[torch.Tensor]) – Router weights for each selected expert, shape [token_num, top_k].
shared_expert_output (Optional[torch.Tensor]) – Optional shared-expert output to add, shape [token_num, hidden_dim].
block_quant_group_size (Optional[int]) – Group size for the per-token-group FP8 packed quantization patterns (TRT-LLM only).
weight_bias (float) – Bias added to rms_gamma before scaling.
  0.0 (default): standard RMSNorm (out = gamma * x * rsqrt(...)).
  1.0: Gemma / Qwen3.5 RMSNorm (out = (1 + gamma) * x * rsqrt(...)).
Supported by both the TRT-LLM and MNNVL backends for the standard RMSNorm and quant patterns (1-5), and by TRT-LLM for the MoE RMSNorm variants. Ignored for kAllReduce.
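The MoE finalize inputs (expanded_idx_to_permuted_idx, expert_scale_factor, shared_expert_output) can be read as a gather-and-weight step that runs before the AllReduce. The plain-Python sketch below assumes that reading: for each token it gathers the top_k permuted expert rows named by expanded_idx_to_permuted_idx, weights them by expert_scale_factor, and adds shared_expert_output when given. The function name and permuted_rows are hypothetical (this page does not say which argument carries the permuted expert outputs), and the kernel's summation order and precision may differ.

```python
from typing import List, Optional, Sequence


def reference_moe_finalize(
    permuted_rows: Sequence[Sequence[float]],
    expanded_idx_to_permuted_idx: Sequence[Sequence[int]],
    expert_scale_factor: Sequence[Sequence[float]],
    shared_expert_output: Optional[Sequence[Sequence[float]]] = None,
) -> List[List[float]]:
    """Hypothetical per-rank MoE finalize: weighted gather of expert rows.

    permuted_rows:                [num_permuted_rows][hidden_dim]
    expanded_idx_to_permuted_idx: [token_num][top_k] row indices into permuted_rows
    expert_scale_factor:          [token_num][top_k] router weights
    shared_expert_output:         optional [token_num][hidden_dim]
    """
    hidden_dim = len(permuted_rows[0])
    finalized = []
    for t, row_indices in enumerate(expanded_idx_to_permuted_idx):
        # Start from the shared-expert output if present, else zeros.
        if shared_expert_output is None:
            acc = [0.0] * hidden_dim
        else:
            acc = list(shared_expert_output[t])
        # Add each selected expert's output row, scaled by its router weight.
        for k, row_idx in enumerate(row_indices):
            weight = expert_scale_factor[t][k]
            row = permuted_rows[row_idx]
            for i in range(hidden_dim):
                acc[i] += weight * row[i]
        finalized.append(acc)
    return finalized


rows = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
idx = [[2, 0], [1, 2]]
scales = [[0.5, 0.5], [1.0, 0.25]]
finalized = reference_moe_finalize(rows, idx, scales)
```

Under this reading, each rank's finalized rows would then go through the same AllReduce + Residual + RMSNorm stages as the non-MoE patterns.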
- Returns:
Output tensor for the selected pattern. Quant patterns return quant_out, RMSNorm patterns return norm_out, and kAllReduce returns output.
- Return type:
torch.Tensor
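The return rule and the documented quant_out shapes can be captured in a small lookup. This sketch uses a local IntEnum named Pattern that mirrors the documented AllReduceFusionPattern values; it is not the library's class. It also assumes the per-token-group FP8 packed patterns (8, 9) count as quant patterns, and it only knows the quant_out shapes this page documents for per-tensor FP8 and NVFP4.

```python
from enum import IntEnum
from typing import Tuple


class Pattern(IntEnum):
    """Local mirror of the documented AllReduceFusionPattern values."""

    kAllReduce = 0
    kARResidualRMSNorm = 1
    kARResidualRMSNormFP8Quant = 2
    kARResidualRMSNormFP4Quant = 3
    kARResidualRMSNormOutFP8Quant = 4
    kARResidualRMSNormOutFP4Quant = 5
    kMoEReductionARResidualRMSNorm = 6
    kMoEFinalizeARResidualRMSNorm = 7
    kARResidualRMSNormPerTokenGroupFP8PackedQuant = 8
    kARResidualRMSNormOutPerTokenGroupFP8PackedQuant = 9


FP8_PER_TENSOR = frozenset({Pattern.kARResidualRMSNormFP8Quant, Pattern.kARResidualRMSNormOutFP8Quant})
NVFP4 = frozenset({Pattern.kARResidualRMSNormFP4Quant, Pattern.kARResidualRMSNormOutFP4Quant})
# Assumption: the packed group-quant patterns also return quant_out.
GROUP_FP8 = frozenset({
    Pattern.kARResidualRMSNormPerTokenGroupFP8PackedQuant,
    Pattern.kARResidualRMSNormOutPerTokenGroupFP8PackedQuant,
})
QUANT = FP8_PER_TENSOR | NVFP4 | GROUP_FP8


def returned_buffer(pattern: int) -> str:
    """Name of the buffer allreduce_fusion returns for `pattern`."""
    p = Pattern(pattern)  # raises ValueError for unknown values
    if p == Pattern.kAllReduce:
        return "output"
    if p in QUANT:
        return "quant_out"
    return "norm_out"  # RMSNorm patterns: 1, 6, 7


def quant_out_shape(pattern: int, token_num: int, hidden_dim: int) -> Tuple[int, int]:
    """quant_out shape for the quant patterns whose shape is documented here."""
    p = Pattern(pattern)
    if p in FP8_PER_TENSOR:
        return (token_num, hidden_dim)
    if p in NVFP4:
        # Two 4-bit values share one byte, so the last dimension is halved.
        return (token_num, hidden_dim // 2)
    raise ValueError(f"quant_out shape is not documented here for {p.name}")
```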
Examples
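>>> import torch
>>> from flashinfer.comm import (
...     AllReduceFusionPattern,
...     allreduce_fusion,
...     create_allreduce_fusion_workspace,
... )
>>> # Basic AllReduce + Residual + RMSNorm, run on every rank with that rank's id.
>>> # hidden_states, residual: [token_num, 4096] bf16 CUDA tensors; norm_weight: [4096]
>>> workspace = create_allreduce_fusion_workspace(
...     backend="auto", world_size=8, rank=0,
...     max_token_num=2048, hidden_dim=4096, dtype=torch.bfloat16,
... )
>>> prenorm = torch.empty_like(hidden_states)
>>> normed = torch.empty_like(hidden_states)
>>> output = allreduce_fusion(  # returns norm_out for this pattern
...     input=hidden_states,
...     workspace=workspace,
...     pattern=AllReduceFusionPattern.kARResidualRMSNorm,
...     launch_with_pdl=True,
...     residual_out=prenorm,
...     norm_out=normed,
...     residual_in=residual,
...     rms_gamma=norm_weight,
... )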
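Following the Note on workspace reuse, a caller that sees varying batch sizes might check the workspace before each call and rebuild it only when it is too small. The helper below is a hypothetical sketch: make_workspace stands in for a factory that would call create_allreduce_fusion_workspace() with a larger max_token_num, and _StandInWorkspace is a test double, not a FlashInfer class. The sketch does not release the old workspace.

```python
from typing import Any, Callable, Optional


def reuse_or_recreate(
    workspace: Optional[Any],
    make_workspace: Callable[[int], Any],
    tp_size: int,
    num_tokens: int,
    hidden_dim: int,
    dtype: Any,
) -> Any:
    """Return `workspace` if it fits this call, otherwise build a new one.

    Hypothetical helper; make_workspace(num_tokens) is an assumed factory.
    """
    if workspace is not None and workspace.is_buffer_size_sufficient(
        tp_size, num_tokens, hidden_dim, dtype
    ):
        return workspace
    return make_workspace(num_tokens)


class _StandInWorkspace:
    """Test double: 'sufficient' just means num_tokens <= max_token_num.

    The real check also depends on tp_size, hidden_dim and dtype.
    """

    def __init__(self, max_token_num: int) -> None:
        self.max_token_num = max_token_num

    def is_buffer_size_sufficient(self, tp_size, num_tokens, hidden_dim, dtype) -> bool:
        return num_tokens <= self.max_token_num


ws = _StandInWorkspace(max_token_num=2048)
same = reuse_or_recreate(ws, _StandInWorkspace, 8, 1024, 4096, "bfloat16")
grown = reuse_or_recreate(ws, _StandInWorkspace, 8, 4096, 4096, "bfloat16")
```

Checking before each call keeps the common case (a shape that fits) free of reallocation, while still handling an occasional oversized batch.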