flashinfer.comm

This module provides communication primitives and utilities for distributed computing, including CUDA IPC, AllReduce operations, and memory management utilities.

CUDA IPC Utilities

CudaRTLibrary([so_file])

create_shared_buffer(size_in_bytes[, group])

Allocate a buffer and share it across the process group via CUDA IPC.

free_shared_buffer(pointers[, group])

Free a shared buffer previously created by create_shared_buffer().

DLPack Utilities

pack_strided_memory(ptr, segment_size, ...)

Pack a strided device allocation as a PyTorch tensor view.

Mapping Utilities

Mapping([world_size, rank, gpus_per_node, ...])

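Describes the parallel layout (world size, rank, and tensor/context/pipeline-parallel group sizes). Example: a node with 8 GPUs, tp_size = 4, cp_size = 1, pp_size = 2.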
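To make that 8-GPU example concrete, the sketch below enumerates its tensor-parallel and pipeline-parallel groups. It assumes cp_size = 1 and the rank ordering rank = pp_rank * tp_size + tp_rank (the usual TensorRT-LLM convention). The helper functions are hypothetical and are not part of the Mapping API.

```python
from typing import List, Tuple

# Hypothetical helpers; assumes cp_size == 1 and rank = pp_rank * tp_size + tp_rank.


def tp_groups(world_size: int, tp_size: int) -> List[List[int]]:
    """Consecutive ranks share a tensor-parallel group."""
    return [list(range(start, start + tp_size)) for start in range(0, world_size, tp_size)]


def pp_groups(world_size: int, tp_size: int) -> List[List[int]]:
    """Ranks with the same tp_rank form a pipeline-parallel group."""
    return [list(range(tp_rank, world_size, tp_size)) for tp_rank in range(tp_size)]


def coords(rank: int, tp_size: int) -> Tuple[int, int]:
    """Return (pp_rank, tp_rank) for a global rank."""
    return divmod(rank, tp_size)


if __name__ == "__main__":
    print(tp_groups(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
    print(pp_groups(8, 4))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
    print(coords(6, 4))     # (1, 2)
```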

TensorRT-LLM AllReduce

Types and Enums

Core Operations

trtllm_allreduce_fusion(allreduce_in, ...[, ...])

Parameters:
  • allreduce_in – the input tensor, shape [token_num, hidden_dim].
  • world_size – the size of the process group.
  • world_rank – the rank of the current process.
  • token_num – the number of tokens in the sequence.
  • hidden_dim – the dimension of the hidden states.
  • workspace_ptrs – the workspace pointers.
  • launch_with_pdl – whether to launch with PDL (programmatic dependent launch).
  • use_oneshot – whether to use the one-shot algorithm. If None, internal heuristics decide.
  • trigger_completion_at_end – whether to trigger completion at the end.
  • fp32_acc – whether to use FP32 accumulation.
  • pattern_code – the fusion pattern code.
  • allreduce_out – the output tensor, shape [token_num, hidden_dim].
  • residual_in – the residual input tensor, shape [token_num, hidden_dim].
  • residual_out – the residual output tensor, shape [token_num, hidden_dim].
  • norm_out – the norm output tensor, shape [token_num, hidden_dim].
  • quant_out – the quantized output tensor, shape [token_num, hidden_dim].
  • scale_out – the scale output tensor. For initialization, see tests/comm/test_trtllm_allreduce_fusion.py.
  • rms_gamma – the RMSNorm gamma tensor, shape [hidden_dim].
  • rms_eps – the RMSNorm epsilon value.
  • scale_factor – the scale factor. For CUDA graph safety, pass it as a tensor.
  • layout_code – the layout code.
  • metadata – optional workspace metadata dict from create_ipc_workspace_for_all_reduce_fusion. If provided, validates that token_num <= max_token_num, world_size == tp_size, and hidden_dim equals the workspace hidden_dim; raises ValueError if validation fails.
  • block_quant_group_size – group size (in elements along hidden_dim) for per-token-group block-wise FP8 quantization patterns (e.g. kPerTokenGroupFP8Packed / DeepSeek-style FP8 with UE8M0 packed scales), i.e. the number of consecutive elements that share a single scale factor. Must be > 0 and divide hidden_dim when the pattern requires it; ignored (treated as 0 / unused) for patterns that do not perform block quantization.
  • weight_bias – bias added to rms_gamma before scaling. None or 0.0 gives standard RMSNorm (out = gamma * x * rsqrt(...)); 1.0 gives Gemma / Qwen3.5 RMSNorm (out = (1 + gamma) * x * rsqrt(...)). Ignored for kAllReduce and quant-only patterns that don't apply RMSNorm.
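For reference, the plain-Python sketch below spells out what an AllReduce + residual + RMSNorm pattern computes for one token row, how weight_bias changes the norm, and how per-group power-of-two scales could be derived for block_quant_group_size. It assumes residual_out = allreduce_out + residual_in and norm_out = RMSNorm(residual_out), FP8 E4M3 with a largest finite value of 448, and UE8M0 scales rounded up to a power of two. These are illustrative assumptions, not the kernel's exact numerics, and all helper names are hypothetical.

```python
import math
from typing import List, Tuple

# Reference semantics in plain Python (no FlashInfer calls). Vectors are one token's
# [hidden_dim] row; helper names are hypothetical.


def allreduce_sum(per_rank_inputs: List[List[float]]) -> List[float]:
    """Element-wise sum over ranks: the value every rank holds after AllReduce."""
    return [sum(vals) for vals in zip(*per_rank_inputs)]


def rmsnorm(x: List[float], gamma: List[float], eps: float, weight_bias: float = 0.0) -> List[float]:
    """out = (weight_bias + gamma) * x * rsqrt(mean(x^2) + eps)."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [(weight_bias + g) * v * inv_rms for g, v in zip(gamma, x)]


def fused_reference(per_rank_inputs, residual_in, gamma, eps, weight_bias=0.0) -> Tuple[List[float], List[float]]:
    """AllReduce, add the residual, then RMSNorm; returns (residual_out, norm_out)."""
    reduced = allreduce_sum(per_rank_inputs)
    residual_out = [a + r for a, r in zip(reduced, residual_in)]
    return residual_out, rmsnorm(residual_out, gamma, eps, weight_bias)


def group_scales(row: List[float], group_size: int, fp8_max: float = 448.0) -> List[float]:
    """One power-of-two scale per `group_size` consecutive elements (assumed UE8M0)."""
    if group_size <= 0 or len(row) % group_size != 0:
        raise ValueError("group_size must be > 0 and divide hidden_dim")
    scales = []
    for start in range(0, len(row), group_size):
        amax = max(abs(v) for v in row[start:start + group_size])
        ratio = max(amax, 1e-12) / fp8_max  # avoid log2(0) for all-zero groups
        scales.append(2.0 ** math.ceil(math.log2(ratio)))
    return scales


if __name__ == "__main__":
    res, norm = fused_reference([[1.0, 1.0], [2.0, 3.0]], [0.0, 0.0], [1.0, 1.0], eps=0.0)
    print(res)   # [3.0, 4.0]
    print(norm)  # approximately [0.8485, 1.1314]
    print(group_scales([448.0, 1.0, 896.0, 2.0], group_size=2))  # [1.0, 2.0]
```

With weight_bias = 1.0 and gamma stored as zeros, rmsnorm returns the same values as weight_bias = 0.0 with gamma of ones, which is the Gemma-style convention described above.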

trtllm_custom_all_reduce(inp, out, tp_size, ...)

Parameters: - inp: the input tensor.

trtllm_moe_allreduce_fusion(world_size, ...)

Parameters:
  • world_size – the size of the process group.
  • world_rank – the rank of the current process.
  • token_num – the number of tokens in the sequence.
  • hidden_dim – the dimension of the hidden states.
  • workspace_ptrs – the workspace pointers.
  • launch_with_pdl – whether to launch with PDL (programmatic dependent launch).
  • residual_in – the residual input tensor, shape [token_num, hidden_dim].
  • rms_gamma – the RMSNorm gamma tensor, shape [hidden_dim].
  • rms_eps – the RMSNorm epsilon value.
  • scale_factor – the scale factor.
  • moe_reduction_device_num_experts – the number of experts.
  • moe_reduction_scale_input – the scale input tensor, shape [token_num, hidden_dim].
  • moe_reduction_active_experts_token_input – the active-experts token input tensor, shape [token_num, hidden_dim].
  • moe_reduction_token_input – the token input tensor, shape [token_num, hidden_dim].
  • layout_code – the layout code.
  • moe_allreduce_out – the MoE all-reduce output tensor, shape [token_num, hidden_dim].
  • residual_out – the residual output tensor, shape [token_num, hidden_dim].
  • norm_out – the norm output tensor, shape [token_num, hidden_dim].
  • quant_out – the quantized output tensor, shape [token_num // 4, hidden_dim] (FP16/BF16 input quantized to FP4).
  • scale_out – the scale output tensor. For initialization, see tests/comm/test_trtllm_moe_allreduce_fusion.py.
  • weight_bias – bias added to rms_gamma before scaling. None or 0.0 gives standard RMSNorm (out = gamma * x * rsqrt(...)); 1.0 gives Gemma / Qwen3.5 RMSNorm (out = (1 + gamma) * x * rsqrt(...)).

trtllm_moe_finalize_allreduce_fusion(...[, ...])

Parameters:
  • allreduce_in – the input tensor, shape [token_num, top_k, hidden_dim].
  • residual_in – the residual input tensor, shape [token_num, hidden_dim].
  • norm_weight – the norm weight tensor, shape [hidden_dim].
  • expanded_idx_to_permuted_idx – the expanded-index to permuted-index tensor, shape [token_num, top_k].
  • norm_out – the norm output tensor, shape [token_num, hidden_dim].
  • residual_out – the residual output tensor, shape [token_num, hidden_dim].
  • quant_out – the quantized output tensor, shape [token_num // 4, hidden_dim] (FP16/BF16 input quantized to FP4).
  • scale_out – the scale output tensor, shape [token_num // SF_VEC_SIZE, hidden_dim] (FP16/BF16 input quantized to FP4).
  • workspace_ptrs – the workspace pointers.
  • launch_with_pdl – whether to launch with PDL (programmatic dependent launch).
  • world_rank – the rank of the current process.
  • world_size – the size of the process group.
  • eps – the epsilon value.
  • shared_expert_output – the shared-expert output tensor, shape [token_num, hidden_dim].
  • expert_scale_factor – the expert scale factor tensor, shape [token_num, top_k].
  • routed_scaling_factor – the routed scaling factor.
  • weight_bias – bias added to the norm weight (norm_weight) before scaling. None or 0.0 gives standard RMSNorm (out = gamma * x * rsqrt(...)); 1.0 gives Gemma / Qwen3.5 RMSNorm (out = (1 + gamma) * x * rsqrt(...)).

Workspace Management

trtllm_create_ipc_workspace_for_all_reduce(...)

Parameters: - rank: the rank of the current process.

trtllm_create_ipc_workspace_for_all_reduce_fusion(...)

Parameters: - tp_rank: the rank of the current process.

trtllm_destroy_ipc_workspace_for_all_reduce(...)

Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce.

trtllm_destroy_ipc_workspace_for_all_reduce_fusion(...)

Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce_fusion.

Initialization and Utilities

trtllm_lamport_initialize(buffer_ptr, size, ...)

Initialize a single Lamport-style buffer to negative zero.

trtllm_lamport_initialize_all(buffer_0_ptr, ...)

Initialize three Lamport buffers to negative zero.
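Negative zero works as an "empty slot" marker because it is bit-distinct from +0.0 while comparing equal to it. The plain-Python sketch below illustrates that general Lamport-style convention, assuming readers poll until a slot no longer holds -0.0 and writers normalize a genuine -0.0 result to +0.0. The helper names are hypothetical and this is not FlashInfer code.

```python
import math
from typing import List, Optional

# Hypothetical helpers modelling the "-0.0 means not yet written" convention.

NEG_ZERO = -0.0


def lamport_init(size: int) -> List[float]:
    """Fill a buffer with the -0.0 sentinel."""
    return [NEG_ZERO] * size


def is_empty(value: float) -> bool:
    """-0.0 == 0.0 under IEEE comparison, so the sign bit must be checked explicitly."""
    return value == 0.0 and math.copysign(1.0, value) < 0.0


def publish(value: float) -> float:
    """Writer side (assumed): turn a genuine -0.0 into +0.0 so it is never read as empty."""
    return 0.0 if value == 0.0 else value


def try_read(buffer: List[float], idx: int) -> Optional[float]:
    """Reader side: None means the peer has not written yet, so keep polling."""
    value = buffer[idx]
    return None if is_empty(value) else value


if __name__ == "__main__":
    buf = lamport_init(2)
    print(try_read(buf, 0))  # None: slot still holds the sentinel
    buf[0] = publish(-0.0)   # a real zero result, sign normalized
    print(try_read(buf, 0))  # 0.0
```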

compute_fp4_swizzled_layout_sf_size(...)

Compute the padded size (rows times columns) of the FP4 swizzled layout.
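A sketch of the padding arithmetic, assuming the swizzled scale-factor layout is tiled in 128-row by 4-column blocks (the NVFP4 128x4 layout) and that the column count is hidden_dim // sf_vec_size with sf_vec_size = 16. The helper names are hypothetical and may not match the library function's signature.

```python
# Hypothetical helpers; assumes 128-row x 4-column tiles for the swizzled scale factors.


def pad_up(value: int, multiple: int) -> int:
    """Round value up to the next multiple of `multiple`."""
    return (value + multiple - 1) // multiple * multiple


def fp4_swizzled_sf_size(total_row: int, total_column: int) -> int:
    """Padded element count (rows times columns) of the tiled scale-factor layout."""
    return pad_up(total_row, 128) * pad_up(total_column, 4)


if __name__ == "__main__":
    # 5 tokens, hidden_dim 4096, one scale per 16 elements -> 256 scale columns.
    print(fp4_swizzled_sf_size(5, 4096 // 16))  # 32768 (128 padded rows * 256 columns)
```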

Unified AllReduce Fusion API

allreduce_fusion(input, workspace, pattern)

AllReduce + RMSNorm fusion operation, with optional FP8/NVFP4 quantization for supported backends.

create_allreduce_fusion_workspace([backend, ...])

Create workspace for AllReduce fusion operations.

AllReduceFusionWorkspace(world_size, rank)

Base class for AllReduce fusion workspaces.

TRTLLMAllReduceFusionWorkspace(tp_size, ...)

TensorRT-LLM workspace for AllReduce fusion.

MNNVLAllReduceFusionWorkspace(mapping[, ...])

vLLM AllReduce

vllm_all_reduce(fa, inp, out, reg_buffer, ...)

Perform an out-of-place all-reduce via the vLLM custom kernel.

vllm_dispose(fa)

Release the resources held by a vLLM custom all-reduce handle.

vllm_init_custom_ar(ipc_tensors, rank_data, ...)

Initialize the vLLM custom all-reduce backend.

vllm_register_buffer(fa, fake_ipc_ptrs)

Register a peer's IPC-shared buffer with the local all-reduce handle.

vllm_register_graph_buffers(fa, handles, offsets)

Register graph-capture buffers across the all-reduce world.

vllm_get_graph_buffer_ipc_meta(fa)

Return IPC metadata for graph-capture buffers.

vllm_meta_size()

Return the size of the vLLM all-reduce metadata structure in bytes.

TensorRT-LLM MNNVL AllReduce

trtllm_mnnvl_all_reduce(inp, ...[, out])

Deprecated pointer-based MNNVL all-reduce API.

trtllm_mnnvl_allreduce(input, workspace, ...)

Perform an MNNVL all-reduce sum across tensor-parallel ranks.

trtllm_mnnvl_fused_allreduce_add_rmsnorm(...)

Perform MNNVL AllReduce + Residual + RMSNorm.

trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant(...)

Perform MNNVL AllReduce + Residual + RMSNorm + FP8/NVFP4 quantization.

trtllm_mnnvl_fused_allreduce_rmsnorm(...)

Perform MNNVL two-shot AllReduce + RMSNorm.
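A two-shot all-reduce is conventionally a reduce-scatter followed by an all-gather: each rank reduces only its 1/world_size shard, then the reduced shards are exchanged so every rank holds the full result (a one-shot variant instead has every rank reduce all peers' full inputs directly). The single-process simulation below illustrates that general technique only. It does not model MNNVL transport or the fused RMSNorm, and the function name is hypothetical.

```python
from typing import List


def two_shot_allreduce(per_rank_inputs: List[List[float]]) -> List[List[float]]:
    """Simulate reduce-scatter + all-gather; returns the vector each rank ends up with."""
    world_size = len(per_rank_inputs)
    n = len(per_rank_inputs[0])
    if n % world_size != 0:
        raise ValueError("vector length must be divisible by world_size")
    shard = n // world_size

    # Shot 1 (reduce-scatter): rank r sums shard r across all ranks.
    reduced_shards = []
    for r in range(world_size):
        lo, hi = r * shard, (r + 1) * shard
        reduced_shards.append([sum(vec[i] for vec in per_rank_inputs) for i in range(lo, hi)])

    # Shot 2 (all-gather): every rank collects all reduced shards in order.
    full = [v for s in reduced_shards for v in s]
    return [list(full) for _ in range(world_size)]


if __name__ == "__main__":
    print(two_shot_allreduce([[1, 2, 3, 4], [10, 20, 30, 40]]))
    # [[11, 22, 33, 44], [11, 22, 33, 44]]
```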

mpi_barrier()

MNNVL A2A (Throughput Backend)

moe_a2a_initialize(workspace, ep_rank, ...)

Initialize the MoE all-to-all workspace and return a metainfo tensor.

moe_a2a_dispatch(token_selected_experts, ...)

Dispatch tokens and payloads to their target expert ranks.

moe_a2a_combine(payload, local_num_tokens, ...)

Combine per-expert outputs back to the originating ranks.
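The dispatch/combine pair can be pictured with the toy, single-process model below. It is a sketch under stated assumptions, not the FlashInfer implementation: experts are assumed to be partitioned contiguously across ranks, tokens are scalars, every expert is the identity, top-k weighting is ignored, and in this model a token is sent once to each distinct destination rank. The receive buffers mirror the [ep_size, runtime_max_tokens_per_rank, *] layout, indexed by source rank and token slot. All names are hypothetical.

```python
from typing import List, Optional

# Toy model of MoE all-to-all dispatch/combine; all names are hypothetical.


def owner_rank(expert_id: int, num_local_experts: int) -> int:
    """Assumed contiguous partitioning: rank r owns experts [r*n, (r+1)*n)."""
    return expert_id // num_local_experts


def dispatch(hidden, selected, ep_size, num_local_experts, max_tokens):
    """Return recv[dst][src][slot]; slots a destination does not need stay None."""
    recv: List[List[List[Optional[float]]]] = [
        [[None] * max_tokens for _ in range(ep_size)] for _ in range(ep_size)
    ]
    for src in range(ep_size):
        for t, experts in enumerate(selected[src]):
            # Send the token once to each distinct rank owning one of its experts.
            for dst in {owner_rank(e, num_local_experts) for e in experts}:
                recv[dst][src][t] = hidden[src][t]
    return recv


def run_local_experts(recv_dst, selected, dst, num_local_experts):
    """Identity experts: return token * (number of this rank's experts it selected)."""
    out = [[0.0] * len(row) for row in recv_dst]
    for src, row in enumerate(recv_dst):
        for t, value in enumerate(row):
            if value is not None:
                hits = sum(owner_rank(e, num_local_experts) == dst for e in selected[src][t])
                out[src][t] = value * hits
    return out


def combine(processed, ep_size, local_num_tokens):
    """Sum every rank's partial output for a token back on the token's source rank."""
    return [
        [sum(processed[dst][src][t] for dst in range(ep_size)) for t in range(local_num_tokens[src])]
        for src in range(ep_size)
    ]


if __name__ == "__main__":
    hidden = [[1.0, 2.0], [3.0]]             # rank 0 has 2 tokens, rank 1 has 1
    selected = [[[0, 3], [1, 0]], [[2, 3]]]  # top_k = 2, 4 experts, 2 per rank
    recv = dispatch(hidden, selected, ep_size=2, num_local_experts=2, max_tokens=2)
    processed = [run_local_experts(recv[d], selected, d, 2) for d in range(2)]
    print(combine(processed, 2, [2, 1]))     # [[2.0, 4.0], [6.0]]: each token scaled by top_k
```

Because every toy expert is the identity, the combined result equals each token multiplied by top_k, which makes the round trip easy to check.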

moe_a2a_sanitize_expert_ids(expert_ids, ...)

Replace expert IDs not owned by this rank with invalid_expert_id.
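As a sketch of the sanitization step, assuming experts are partitioned contiguously across ranks (rank r owns experts r * num_local_experts through (r + 1) * num_local_experts - 1), the replacement could look like this. The function name and arguments are hypothetical, and the real primitive's ownership rule may differ.

```python
from typing import List


def sanitize_expert_ids(
    expert_ids: List[List[int]], ep_rank: int, num_local_experts: int, invalid_expert_id: int
) -> List[List[int]]:
    """Replace expert IDs owned by other ranks with invalid_expert_id (contiguous ownership assumed)."""
    lo = ep_rank * num_local_experts
    hi = lo + num_local_experts
    return [[e if lo <= e < hi else invalid_expert_id for e in row] for row in expert_ids]


if __name__ == "__main__":
    # Rank 1 owns experts 2 and 3 when there are 2 experts per rank.
    print(sanitize_expert_ids([[0, 3], [2, 1]], ep_rank=1, num_local_experts=2, invalid_expert_id=-1))
    # [[-1, 3], [2, -1]]
```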

moe_a2a_get_workspace_size_per_rank(ep_size, ...)

Compute the per-rank workspace size for the MoE all-to-all primitive.

moe_a2a_wrap_payload_tensor_in_workspace(...)

Wrap a slice of the shared workspace as a typed tensor view.

class flashinfer.comm.MoeAlltoAll(mapping: Mapping, max_num_tokens: int, top_k: int, num_experts: int, workspace_size_per_rank: int = None, hidden_size: int = None, mnnvl_config: MnnvlConfig | None = None)

Bases: object

Manages MoE All-to-All operations with proper workspace allocation and synchronization.

This class provides the throughput-optimized backend that supports multiple payloads per collective operation, explicit dispatch/combine phases, and workspace-backed tensors.

Example

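>>> # experts: [batch_size, top_k] int32 expert IDs; hidden, ids, scales: per-token payloads
>>> moe_a2a = MoeAlltoAll(mapping, max_num_tokens=2048, top_k=2, num_experts=8,
...                       hidden_size=hidden.shape[-1])
>>> recv = moe_a2a.dispatch(experts, [hidden, ids, scales], batch_size)
>>> # Run the local experts on recv[0] ([ep_size, batch_size, hidden_size]) to produce `processed`
>>> output = moe_a2a.combine(processed, batch_size)
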
__init__(mapping: Mapping, max_num_tokens: int, top_k: int, num_experts: int, workspace_size_per_rank: int = None, hidden_size: int = None, mnnvl_config: MnnvlConfig | None = None)

Initialize MoeAlltoAll and allocate the shared workspace.

Parameters:
  • mapping (Mapping) – Mapping object describing the parallel layout (must expose moe_ep_rank and moe_ep_size).

  • max_num_tokens (int) – Maximum number of tokens this rank will dispatch in any single call.

  • top_k (int) – Number of experts assigned per token.

  • num_experts (int) – Total number of experts (across all ranks).

  • workspace_size_per_rank (int, optional) – Pre-computed workspace size in bytes per rank. When None, hidden_size must be provided and the workspace is sized via get_moe_workspace_size_per_rank().

  • hidden_size (int, optional) – Hidden dimension size, used to derive workspace_size_per_rank when the latter is omitted.

  • mnnvl_config (MnnvlConfig, optional) – Optional configuration for the underlying MNNVL communication backend.

combine(payload: Tensor, runtime_max_tokens_per_rank: int, payload_in_workspace: bool = False) → Tensor

Run the MoE all-to-all combine phase.

Parameters:
  • payload (torch.Tensor) – [ep_size, runtime_max_tokens_per_rank, elements_per_token] output payload to scatter back to source ranks.

  • runtime_max_tokens_per_rank (int) – Maximum tokens per rank in this batch (same value passed to dispatch()).

  • payload_in_workspace (bool) – True if payload is already a workspace-backed view (skips the staging copy). Defaults to False.

Returns:

[local_num_tokens, elements_per_token] combined tensor.

Return type:

torch.Tensor

dispatch(token_selected_experts: Tensor, input_payloads: list[Tensor], runtime_max_tokens_per_rank: int, invalid_token_expert_id: int | None = None, expert_id_payload_index: int | None = None) → list[Tensor]

Run the MoE all-to-all dispatch phase.

Parameters:
  • token_selected_experts (torch.Tensor) – [local_num_tokens, top_k] int32 tensor of expert assignments.

  • input_payloads (list[torch.Tensor]) – Per-token payload tensors, each shaped [local_num_tokens, *].

  • runtime_max_tokens_per_rank (int) – Maximum tokens per rank in this batch. Must be <= max_num_tokens used at construction.

  • invalid_token_expert_id (int, optional) – If supplied, expert IDs not owned by the current rank are rewritten to this value. Requires expert_id_payload_index.

  • expert_id_payload_index (int, optional) – Index into input_payloads that holds the expert IDs to sanitize. Required when invalid_token_expert_id is set.

Returns:

Workspace-backed receive tensors, one per input_payloads entry, each shaped [ep_size, runtime_max_tokens_per_rank, *].

Return type:

list[torch.Tensor]

get_combine_payload_tensor_in_workspace(runtime_max_tokens_per_rank: int, hidden_size: int, dtype: dtype) → Tensor

Return a workspace-backed view to use as the combine payload.

Zero-copy variant of combine(): experts can write directly into the returned tensor and call combine() with payload_in_workspace=True. Must be called after a successful dispatch() and before combine().

Parameters:
  • runtime_max_tokens_per_rank (int) – Maximum tokens per rank in this batch.

  • hidden_size (int) – Hidden dimension size.

  • dtype (torch.dtype) – Element dtype of the resulting view.

Returns:

[ep_size, runtime_max_tokens_per_rank, hidden_size] workspace-backed tensor.

Return type:

torch.Tensor

Raises:

RuntimeError – If called before a successful dispatch().

static get_moe_workspace_size_per_rank(ep_size: int, top_k: int, max_num_tokens: int, hidden_size: int, extra_payload_bytes_per_token: int = 0) → int

Compute the per-rank workspace size for the MoE all-to-all primitive.

Convenience wrapper around moe_a2a_get_workspace_size_per_rank() that derives the dispatch / combine payload sizes from hidden_size and top_k assuming 16-bit hidden states. For a tighter bound on quantized models use moe_a2a_get_workspace_size_per_rank() directly.

Parameters:
  • ep_size (int) – Total expert-parallel world size.

  • top_k (int) – Number of experts assigned per token.

  • max_num_tokens (int) – Maximum number of tokens across all ranks.

  • hidden_size (int) – Hidden dimension size.

  • extra_payload_bytes_per_token (int) – Extra payload bytes per token to reserve (e.g. for quantization scales). Defaults to 0.

Returns:

Required workspace size per rank, in bytes.

Return type:

int

DCP All-to-All (Context-Parallel Attention Reduction)

decode_cp_a2a_workspace_size(cp_size)

Return the workspace size (in bytes) per rank for the given CP group size.

decode_cp_a2a_allocate_mnnvl_workspace(...)

Allocate an MNNVL-backed workspace of shape [cp_size, ws_elems_per_rank].

decode_cp_a2a_init_workspace(workspace, ...)

Initialize the workspace FIFO buffers (call once before the first alltoall).

decode_cp_a2a_alltoall(partial_o, ...[, ...])

Perform the DCP all-to-all exchange.
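When attention over the KV sequence is split across context-parallel ranks, each rank's partial output can be merged exactly if its log-sum-exp (LSE) travels with it. The sketch below shows that standard LSE-weighted merge, assuming each rank holds a softmax-normalized partial output and its LSE. Whether FlashInfer performs this merge inside decode_cp_a2a_alltoall or in a separate step is not stated here, and the helper names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def attend(scores: Sequence[float], values: Sequence[Sequence[float]]) -> Tuple[List[float], float]:
    """Single-query attention over one KV shard; returns (normalized output, LSE)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    out = [sum(e * v[i] for e, v in zip(exps, values)) / denom for i in range(len(values[0]))]
    return out, m + math.log(denom)


def merge_partials(partials: Sequence[Tuple[List[float], float]]) -> List[float]:
    """Combine per-rank (output, lse) pairs into the full-sequence attention output."""
    max_lse = max(lse for _, lse in partials)
    weights = [math.exp(lse - max_lse) for _, lse in partials]  # shift for numerical stability
    total = sum(weights)
    dim = len(partials[0][0])
    return [sum(w * o[i] for w, (o, _) in zip(weights, partials)) / total for i in range(dim)]


if __name__ == "__main__":
    scores = [1.0, 2.0, 0.5, 3.0]
    values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, 1.0]]
    full, _ = attend(scores, values)
    # Two CP ranks, each holding half of the KV sequence.
    merged = merge_partials([attend(scores[:2], values[:2]), attend(scores[2:], values[2:])])
    print(full, merged)  # the two vectors agree up to rounding
```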

Mixed Communication

MixedCommOp(value[, names, module, ...])

Enumeration of mixed communication operation types.

MixedCommMode(value[, names, module, ...])

Enumeration of mixed communication execution modes.

MixedCommHandler(world_rank, world_size, ...)

Handler implementing the combined all-reduce + all-gather and reduce-scatter + all-reduce operations.

run_mixed_comm(op, handler, x_in[, x_out, mode])

Execute a mixed communication operation.