flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_allreduce

flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_allreduce(input: Tensor, workspace: MNNVLAllReduceFusionWorkspace, launch_with_pdl: bool, output: Tensor | None = None, strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO) → Tensor

Perform an MNNVL all-reduce sum across tensor-parallel ranks.

input must be a 2-D local shard with shape [num_tokens, hidden_dim]. The result has the same shape and is written to output when provided, or to a newly allocated tensor otherwise. workspace must be an MNNVLAllReduceFusionWorkspace created for the same tensor-parallel group and with enough buffer capacity for the selected strategy.
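The shape rules in the previous paragraph can be stated as a small pure-Python check. This is a hypothetical helper (`resolve_output_shape` is not part of FlashInfer) that works on shape tuples instead of tensors. It assumes only what the paragraph says, plus one labeled assumption: a caller-supplied output must have exactly the input's shape.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def resolve_output_shape(input_shape: Shape, output_shape: Optional[Shape] = None) -> Shape:
    """Hypothetical check of the documented shape rules (not a FlashInfer API)."""
    # input must be a 2-D local shard: [num_tokens, hidden_dim]
    if len(input_shape) != 2:
        raise ValueError(f"input must be 2-D [num_tokens, hidden_dim], got {input_shape}")
    # Assumption: a caller-supplied output must have exactly the input's shape.
    if output_shape is not None and tuple(output_shape) != tuple(input_shape):
        raise ValueError(f"output shape {output_shape} != input shape {input_shape}")
    # Whether output is given or freshly allocated, the result has the input's shape.
    return tuple(input_shape)


print(resolve_output_shape((128, 4096)))               # (128, 4096)
print(resolve_output_shape((128, 4096), (128, 4096)))  # (128, 4096)
```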

MNNVL supports two execution strategies:

  • ONESHOT writes each rank’s local shard to the peer-visible workspace; every rank then reads all shards and performs the full reduction locally. This is the low-latency path for smaller payloads.

  • TWOSHOT scatters token slices, reduces each slice on its destination rank, then broadcasts the reduced result. This is the throughput-oriented path for larger payloads.
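The data flow of the two strategies can be modeled on plain Python lists. This is a conceptual sketch only: it runs all "ranks" in one process and does not model the Lamport buffers, memory layout, or kernel scheduling. The contiguous token-slice assignment in `twoshot` is an assumption chosen for illustration, and all function names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]  # one rank's shard: [num_tokens][hidden_dim]


def reduce_token(shards: List[Matrix], t: int) -> List[float]:
    """Sum token row t over all ranks in fixed rank order 0..world_size-1."""
    hidden = len(shards[0][0])
    row = [0.0] * hidden
    for shard in shards:
        for h in range(hidden):
            row[h] += shard[t][h]
    return row


def oneshot(shards: List[Matrix]) -> List[Matrix]:
    """Every rank reads all shards and reduces every token itself."""
    num_tokens = len(shards[0])
    return [[reduce_token(shards, t) for t in range(num_tokens)] for _rank in shards]


def twoshot(shards: List[Matrix]) -> List[Matrix]:
    """Scatter token slices, reduce each slice on its owner, then broadcast."""
    world_size = len(shards)
    num_tokens = len(shards[0])
    chunk = -(-num_tokens // world_size)  # ceil(num_tokens / world_size)
    reduced: List[List[float]] = [[] for _ in range(num_tokens)]
    # Phase 1: destination rank d reduces only its (assumed contiguous) token slice.
    for d in range(world_size):
        for t in range(d * chunk, min((d + 1) * chunk, num_tokens)):
            reduced[t] = reduce_token(shards, t)
    # Phase 2: broadcast, so every rank ends up with every reduced slice.
    return [[list(row) for row in reduced] for _rank in range(world_size)]


# 4 ranks, 3 tokens, hidden_dim 2; integer-valued floats keep every sum exact.
shards = [[[float(10 * r + 2 * t + h) for h in range(2)] for t in range(3)] for r in range(4)]
print(oneshot(shards)[0])  # [[60.0, 64.0], [68.0, 72.0], [76.0, 80.0]]
```

ONESHOT does all the reduction work on every rank but needs only one exchange; TWOSHOT splits the reduction across ranks at the cost of a second (broadcast) phase, which is why it pays off once payloads grow.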

With AUTO, FlashInfer selects the strategy from the payload size. Both ONESHOT and TWOSHOT are fully deterministic across ranks.
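The payload size that drives AUTO is the byte count of one rank's shard, num_tokens × hidden_dim × element size. The sketch below shows that arithmetic with a simple cutoff. `Strategy` is a stand-in for MNNVLAllreduceFusionStrategy, and the 1 MiB cutoff is a hypothetical value, not FlashInfer's real threshold, which this page does not state.

```python
from enum import Enum


class Strategy(Enum):
    """Illustrative stand-in for MNNVLAllreduceFusionStrategy's two concrete values."""
    ONESHOT = "oneshot"
    TWOSHOT = "twoshot"


# HYPOTHETICAL cutoff for illustration only; not FlashInfer's real threshold.
HYPOTHETICAL_ONESHOT_MAX_BYTES = 1 << 20  # 1 MiB


def payload_bytes(num_tokens: int, hidden_dim: int, element_size: int) -> int:
    """Bytes in one rank's [num_tokens, hidden_dim] shard."""
    return num_tokens * hidden_dim * element_size


def pick_strategy(num_tokens: int, hidden_dim: int, element_size: int,
                  max_oneshot_bytes: int = HYPOTHETICAL_ONESHOT_MAX_BYTES) -> Strategy:
    """Small payloads use low-latency ONESHOT; large payloads use TWOSHOT."""
    if payload_bytes(num_tokens, hidden_dim, element_size) <= max_oneshot_bytes:
        return Strategy.ONESHOT
    return Strategy.TWOSHOT


print(pick_strategy(8, 4096, 2))     # 65,536 bytes in bf16: Strategy.ONESHOT
print(pick_strategy(4096, 4096, 2))  # 33,554,432 bytes in bf16: Strategy.TWOSHOT
```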

This allreduce-only helper does not perform quantization. Use trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant(), or the unified flashinfer.comm.allreduce_fusion() API with FP8/NVFP4 AllReduceFusionPattern values, when the post-RMSNorm quantized output is needed.

Determinism:

  • ONESHOT and TWOSHOT each use the same reduction order on every rank, so every rank produces the same result.

  • ONESHOT keeps the local rank’s value in registers; only the remote ranks’ values are volatile-loaded from the Lamport buffer.

  • ONESHOT uses a rank-specialized fast path for world_size <= 8. Larger world sizes use a compact deterministic fallback, because the runtime benefit there is small and specializing every rank would significantly increase JIT compile time.
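A fixed reduction order matters because floating-point addition is not associative: summing the same values in a different order can change the last bit. The sketch below is a scalar model, not kernel code. `reduce_fixed_order` mimics the rule above (local value taken directly, remote values read from a shared buffer, order always rank 0..N-1), while `reduce_rank_relative` is a hypothetical counterexample in which each rank starts from its own slot.

```python
from typing import List


def reduce_fixed_order(local_rank: int, local_value: float, peer_buffer: List[float]) -> float:
    """Sum contributions in global rank order 0..world_size-1 on any rank.

    The local contribution comes from the caller (standing in for a register);
    remote contributions come from peer_buffer (standing in for the Lamport buffer).
    """
    acc = 0.0
    for r in range(len(peer_buffer)):
        acc += local_value if r == local_rank else peer_buffer[r]
    return acc


def reduce_rank_relative(local_rank: int, values: List[float]) -> float:
    """Counterexample (not what the kernels do): each rank starts from its own slot."""
    n = len(values)
    acc = 0.0
    for i in range(n):
        acc += values[(local_rank + i) % n]
    return acc


values = [0.1, 0.2, 0.3]  # one scalar contribution per rank, world_size = 3
fixed = [reduce_fixed_order(r, values[r], values) for r in range(3)]
relative = [reduce_rank_relative(r, values) for r in range(3)]
print(len(set(fixed)), len(set(relative)))  # 1 2
```

With the fixed order, all three ranks agree bit for bit; with the rank-relative order, rank 0 computes (0.1 + 0.2) + 0.3 while rank 1 computes (0.2 + 0.3) + 0.1, and the results differ in the last bit.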

Parameters:
  • input – Local input shard of shape [num_tokens, hidden_dim].

  • workspace – MNNVLAllReduceFusionWorkspace created for the same tensor-parallel group as this call.

  • launch_with_pdl – Whether to launch the kernel with PDL (Programmatic Dependent Launch).

  • output – Optional tensor to store the result. If not provided, a new tensor is allocated.

  • strategy – MNNVLAllreduceFusionStrategy to use. Defaults to AUTO, which selects the strategy with internal heuristics.

Returns:

Reduced tensor of shape [num_tokens, hidden_dim].

Return type:

Tensor