flashinfer.comm.create_allreduce_fusion_workspace

flashinfer.comm.create_allreduce_fusion_workspace(backend: Literal['trtllm', 'mnnvl', 'auto'] = 'auto', world_size: int = None, rank: int = None, max_token_num: int = None, hidden_dim: int = None, dtype: dtype = None, gpus_per_node: int = None, comm_backend: CommBackend | None = None, force_oneshot_support: bool = False, group: ProcessGroup | None = None) → AllReduceFusionWorkspace

Create workspace for AllReduce fusion operations.

Backend selection uses topology-based checks and heuristics.

Important: Workspace Reusability

The workspace is allocated based on the total size (max_token_num * hidden_dim * dtype_size). You can reuse the same workspace with different shapes as long as the total size fits.

Use workspace.is_buffer_size_sufficient(tp_size, num_tokens, hidden_dim, dtype) to check before reusing.
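
The sizing rule above can be illustrated with plain arithmetic. This is a minimal sketch, not FlashInfer code: the `workspace_bytes` and `fits` helpers and the `DTYPE_SIZE_BYTES` table are hypothetical names introduced here, and a backend may apply further constraints, so `workspace.is_buffer_size_sufficient(...)` remains the check to rely on.

```python
# Hypothetical helpers that mirror the documented sizing rule
# (max_token_num * hidden_dim * dtype_size). Not part of FlashInfer.
DTYPE_SIZE_BYTES = {"float32": 4, "float16": 2, "bfloat16": 2}


def workspace_bytes(num_tokens: int, hidden_dim: int, dtype: str) -> int:
    """Total size in bytes for one (tokens, hidden_dim, dtype) shape."""
    return num_tokens * hidden_dim * DTYPE_SIZE_BYTES[dtype]


def fits(allocated: tuple, requested: tuple) -> bool:
    """True if the requested shape needs no more bytes than the allocated one."""
    return workspace_bytes(*requested) <= workspace_bytes(*allocated)


allocated = (2048, 4096, "bfloat16")

# Half the tokens, twice the hidden dimension: same total size, so it fits.
print(fits(allocated, (1024, 8192, "bfloat16")))  # True

# Same shape in float32 needs twice the bytes, so it does not fit.
print(fits(allocated, (2048, 4096, "float32")))  # False
```

In practice this means a workspace sized for the largest expected problem can serve smaller or differently shaped problems without reallocation.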

Parameters:
  • backend (Literal["trtllm", "mnnvl", "auto"]) – Backend to use. "auto" uses a topology-based heuristic to pick between "trtllm" and "mnnvl".

  • world_size (int) – Number of ranks in the process group.

  • rank (int) – Current rank id.

  • max_token_num (int) – Maximum number of tokens the workspace must support.

  • hidden_dim (int) – Hidden dimension size.

  • dtype (torch.dtype) – Element dtype of the communication tensors.

  • gpus_per_node (int, optional) – Number of GPUs per node (used for multi-node topology decisions). Defaults to min(torch.cuda.device_count(), world_size).

  • comm_backend (Optional[CommBackend]) – Communication backend to use for rendezvous. Defaults to the process group's default backend.

  • force_oneshot_support (bool) – If True, allocate workspace for the oneshot strategy up to the largest problem size requested. If False (default), allocate workspace for the twoshot strategy across all problem sizes and for the oneshot strategy up to the heuristic threshold. Only the MNNVL backend needs to be initialized with the correct strategy; the TRT-LLM backend works for both.

  • group (Optional[ProcessGroup]) – Process group used for symmetric-memory rendezvous (TRT-LLM backend only). Defaults to torch.distributed.group.WORLD.
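
The documented default for gpus_per_node can be sketched as follows. This is an illustration only: `resolve_gpus_per_node` is a hypothetical helper, and `visible_devices` stands in for the value of `torch.cuda.device_count()` so that the snippet runs without a GPU.

```python
from typing import Optional


def resolve_gpus_per_node(
    gpus_per_node: Optional[int], visible_devices: int, world_size: int
) -> int:
    """Apply the documented default: min(device count, world_size)."""
    if gpus_per_node is not None:
        # An explicit value always wins over the default.
        return gpus_per_node
    return min(visible_devices, world_size)


# 16 ranks spread over nodes with 8 visible GPUs each.
print(resolve_gpus_per_node(None, 8, 16))  # 8

# 4 ranks on a single 8-GPU node.
print(resolve_gpus_per_node(None, 8, 4))  # 4

# Explicit override.
print(resolve_gpus_per_node(4, 8, 16))  # 4
```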

Returns:

Either a TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace. The workspace type determines which backend allreduce_fusion() will dispatch to.

Return type:

AllReduceFusionWorkspace

Raises:
  • ValueError – If no suitable backend is available for the requested configuration, or if the problem size is not supported by the chosen backend.

  • RuntimeError – If an explicit backend argument is passed that does not match any known backend implementation.
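
One way a caller might handle these exceptions is sketched below. It assumes the caller has a non-fused all-reduce path to fall back to when no backend is available. `try_create_workspace` and `_stub_create` are hypothetical names: the stub replaces create_allreduce_fusion_workspace so the snippet runs without GPUs.

```python
def try_create_workspace(create_fn, **kwargs):
    """Return a workspace, or None when the configuration is unsupported.

    ValueError is treated as "fusion not available here" and handled.
    RuntimeError (unknown backend name) is left to propagate, since it
    points to a mistake in the calling code rather than the environment.
    """
    try:
        return create_fn(**kwargs)
    except ValueError as exc:
        print(f"fused all-reduce unavailable: {exc}")
        return None


def _stub_create(**kwargs):
    # Stand-in for create_allreduce_fusion_workspace (assumption for this sketch).
    raise ValueError("no suitable backend")


workspace = try_create_workspace(
    _stub_create, backend="auto", world_size=8, rank=0
)
print(workspace)  # None
```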

Examples

>>> import torch
>>> from flashinfer.comm import create_allreduce_fusion_workspace
>>> # Auto-select best backend
>>> workspace = create_allreduce_fusion_workspace(
...     backend="auto",
...     world_size=8,
...     rank=0,
...     max_token_num=2048,
...     hidden_dim=4096,
...     dtype=torch.bfloat16,
... )
>>> print(workspace.backend)  # e.g. "trtllm", depending on topology

>>> # Explicit backend selection
>>> workspace = create_allreduce_fusion_workspace(
...     backend="mnnvl",
...     world_size=16,
...     rank=0,
...     max_token_num=2048,
...     hidden_dim=4096,
...     dtype=torch.bfloat16,
... )