flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_allreduce
- flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_allreduce(input: Tensor, workspace: MNNVLAllReduceFusionWorkspace, launch_with_pdl: bool, output: Tensor | None = None, strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO) → Tensor
Perform an MNNVL all-reduce sum across tensor-parallel ranks.
input must be a 2-D local shard with shape [num_tokens, hidden_dim]. The result has the same shape and is written to output when provided, or to a newly allocated tensor otherwise. workspace must be an MNNVLAllReduceFusionWorkspace created for the same tensor-parallel group, with enough buffer capacity for the selected strategy.

MNNVL supports two execution strategies:

- ONESHOT writes each rank's local shard into the peer-visible workspace, and each rank then performs the reduction locally. This is the low-latency path, used for smaller payloads.
- TWOSHOT scatters token slices, reduces each slice on its destination rank, then broadcasts the reduced result. This is the throughput-oriented path, used for larger payloads.
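The data movement of the two strategies can be sketched as a CPU-only simulation in plain Python. This is an illustrative model under stated assumptions, not FlashInfer's kernel: `reduce_row`, `oneshot`, and `twoshot` are hypothetical names, nested lists stand in for tensors, and the contiguous token-slice assignment in `twoshot` is an assumption. Both functions add ranks in the same fixed order, so they return identical results; the counters show how much data a rank receives under each strategy.

```python
from typing import List, Tuple

Matrix = List[List[float]]  # one rank's [num_tokens][hidden_dim] shard


def reduce_row(shards: List[Matrix], t: int) -> List[float]:
    # Sum token row t over ranks in a fixed order: rank 0, 1, ..., N-1.
    row = [0.0] * len(shards[0][t])
    for shard in shards:
        row = [acc + x for acc, x in zip(row, shard[t])]
    return row


def oneshot(shards: List[Matrix]) -> Tuple[List[Matrix], int]:
    # Every rank reads all peers' shards and reduces every row locally.
    n, tokens, hidden = len(shards), len(shards[0]), len(shards[0][0])
    results = [[reduce_row(shards, t) for t in range(tokens)] for _ in range(n)]
    received_per_rank = (n - 1) * tokens * hidden  # all peers' full shards
    return results, received_per_rank


def twoshot(shards: List[Matrix]) -> Tuple[List[Matrix], int]:
    # Shot 1: each rank reduces only the token slice it owns.
    # Shot 2: reduced slices are broadcast so every rank holds the full result.
    n, tokens, hidden = len(shards), len(shards[0]), len(shards[0][0])
    owner = [t * n // tokens for t in range(tokens)]  # assumed contiguous slices
    reduced = [reduce_row(shards, t) for t in range(tokens)]  # row t on rank owner[t]
    results = [[list(row) for row in reduced] for _ in range(n)]
    owned0 = owner.count(0)
    # Rank 0 receives its owned rows from n - 1 peers, then every other row once.
    received_rank0 = ((n - 1) * owned0 + (tokens - owned0)) * hidden
    return results, received_rank0


# 4 simulated ranks, 8 tokens, hidden_dim = 2
shards = [[[float(r + 1), 0.5 * r] for _ in range(8)] for r in range(4)]
one, one_rx = oneshot(shards)
two, two_rx = twoshot(shards)
assert one == two  # identical: same per-element reduction order
print(one_rx, two_rx)  # 48 24
```

Under this simple model, ONESHOT needs a single exchange, which favors latency, while TWOSHOT needs two exchanges but each rank receives about 2(N-1)/N of the tensor instead of N-1 full copies, which favors throughput once world_size > 2.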
With AUTO, FlashInfer selects the strategy from the payload size.

This allreduce-only helper does not perform quantization. When the post-RMSNorm quantized output is needed, use trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant(), or the unified flashinfer.comm.allreduce_fusion() API with the FP8/NVFP4 AllReduceFusionPattern values.

Determinism: ONESHOT and TWOSHOT are both fully deterministic across ranks, because each uses the exact same reduction order on every rank. ONESHOT keeps the local rank's value in registers; only the remote ranks' values are volatile-loaded from the Lamport buffer. ONESHOT uses a rank-specialized fast path for world_size <= 8. Larger world sizes use a compact deterministic fallback, because the runtime benefit of specialization is small at those sizes while specializing every rank significantly increases JIT compile time.
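Why a fixed reduction order matters can be shown with float64 in plain Python, where addition is not associative. The sketch below is illustrative only: `fixed_order_sum` and `local_first_sum` are hypothetical names, and `local_first_sum` models a non-deterministic alternative, not anything FlashInfer is documented to do. In `fixed_order_sum`, the local value comes from a stand-in for a register, yet it still enters the sum at its own rank position, so every rank computes the same answer.

```python
from typing import List


def fixed_order_sum(buffer: List[float], local_value: float, my_rank: int) -> float:
    # Add ranks in order 0..N-1 on every rank. The local term is taken from
    # local_value (standing in for a register) instead of the shared buffer,
    # but it is still added at position my_rank.
    acc = 0.0
    for src in range(len(buffer)):
        acc += local_value if src == my_rank else buffer[src]
    return acc


def local_first_sum(buffer: List[float], local_value: float, my_rank: int) -> float:
    # Counter-example: start from the local term, then add the peers.
    # The order now depends on my_rank, so ranks can disagree.
    acc = local_value
    for src in range(len(buffer)):
        if src != my_rank:
            acc += buffer[src]
    return acc


values = [1e16, 1.0, -1e16, 1.0]  # one element contributed by each of 4 ranks
fixed = [fixed_order_sum(values, values[r], r) for r in range(4)]
local = [local_first_sum(values, values[r], r) for r in range(4)]
print(fixed)  # [1.0, 1.0, 1.0, 1.0]
print(local)  # [1.0, 1.0, 2.0, 0.0]
```

GPU kernels reduce in lower-precision dtypes than float64, but the same non-associativity applies, which is why a shared order across ranks is what keeps the replicated results consistent.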
- Parameters:
input – Local input shard of shape [num_tokens, hidden_dim].
workspace – MNNVLAllReduceFusionWorkspace created for the same tensor-parallel group.
launch_with_pdl – Whether to launch the kernel with PDL (Programmatic Dependent Launch).
output – Optional output tensor for the result. If not provided, a new tensor is allocated.
strategy – MNNVLAllreduceFusionStrategy to use. Defaults to AUTO, which selects the strategy with internal heuristics.
- Returns:
Reduced tensor of shape [num_tokens, hidden_dim] (the output tensor, if one was provided).
- Return type:
Tensor
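When validating results in a test, a CPU reference that follows the documented contract can be handy: every shard is 2-D with the same shape, the result has the input's shape, and it is written into output when one is given. The sketch below is a hypothetical helper in plain Python (nested lists stand in for tensors, and all ranks' shards are passed at once); it is not part of FlashInfer and does not call it.

```python
from typing import List, Optional

Matrix = List[List[float]]


def reference_allreduce(shards: List[Matrix], output: Optional[Matrix] = None) -> Matrix:
    # Hypothetical CPU reference: sums every rank's shard element-wise.
    num_tokens, hidden_dim = len(shards[0]), len(shards[0][0])

    def has_shape(m: Matrix) -> bool:
        return len(m) == num_tokens and all(len(row) == hidden_dim for row in m)

    if not all(has_shape(s) for s in shards):
        raise ValueError("every shard must be 2-D with shape [num_tokens, hidden_dim]")
    if output is None:
        output = [[0.0] * hidden_dim for _ in range(num_tokens)]  # newly allocated
    elif not has_shape(output):
        raise ValueError("output must have the same shape as input")
    for t in range(num_tokens):
        for h in range(hidden_dim):
            acc = 0.0
            for shard in shards:  # fixed rank order 0..N-1
                acc += shard[t][h]
            output[t][h] = acc  # written in place when output was provided
    return output


buf = [[0.0, 0.0]]
res = reference_allreduce([[[1.0, 2.0]], [[3.0, 4.0]]], output=buf)
assert res is buf and buf == [[4.0, 6.0]]
```

When comparing against GPU output, use a tolerance suited to the tensor dtype, since the kernel's accumulation precision may differ from Python's float64.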