flashinfer.comm
This module provides communication primitives and utilities for distributed computing, including CUDA IPC, AllReduce operations, and memory management utilities.
CUDA IPC Utilities
|
|
|
Allocate a buffer and share it across the process group via CUDA IPC. |
|
Free a shared buffer previously created by |
DLPack Utilities
|
Pack a strided device allocation as a PyTorch tensor view. |
Mapping Utilities
|
A node with 8 GPUs, tp_size = 4, cp_size = 1, pp_size = 2 |
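As a rough illustration of the layout named above, the following pure-Python sketch enumerates rank groups for 8 GPUs with tp_size = 4, cp_size = 1, pp_size = 2. It assumes that tensor-parallel ranks are contiguous (TP varies fastest, then CP, then PP); the helper name rank_groups is hypothetical and is not part of flashinfer.comm.

```python
def rank_groups(world_size, tp_size, cp_size, pp_size):
    """Enumerate TP and PP rank groups under an assumed TP-fastest rank order."""
    if world_size != tp_size * cp_size * pp_size:
        raise ValueError("world_size must equal tp_size * cp_size * pp_size")

    stage = tp_size * cp_size  # number of ranks in one pipeline stage

    # Tensor-parallel groups: contiguous runs of tp_size ranks.
    tp_groups = []
    for pp in range(pp_size):
        for cp in range(cp_size):
            start = pp * stage + cp * tp_size
            tp_groups.append(list(range(start, start + tp_size)))

    # Pipeline-parallel groups: same offset within every stage.
    pp_groups = [list(range(offset, world_size, stage)) for offset in range(stage)]
    return tp_groups, pp_groups


tp_groups, pp_groups = rank_groups(world_size=8, tp_size=4, cp_size=1, pp_size=2)
print(tp_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(pp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Under this assumed ordering, ranks in one TP group are adjacent, which on a single node typically keeps the most bandwidth-hungry collective on directly connected GPUs.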
TensorRT-LLM AllReduce
Types and Enums
Core Operations
|
Parameters:
- allreduce_in: the input tensor. [token_num, hidden_dim]
- world_size: the size of the process group.
- world_rank: the rank of the current process.
- token_num: the number of tokens in the sequence.
- hidden_dim: the dimension of the hidden states.
- workspace_ptrs: the workspace pointers.
- launch_with_pdl: whether to launch with PDL (programmatic dependent launch).
- use_oneshot: whether to use the one-shot strategy. If None, internal heuristics will be used.
- trigger_completion_at_end: whether to trigger completion at the end.
- fp32_acc: whether to use fp32 accumulation.
- pattern_code: the pattern code.
- allreduce_out: the output tensor. [token_num, hidden_dim]
- residual_in: the residual input tensor. [token_num, hidden_dim]
- residual_out: the residual output tensor. [token_num, hidden_dim]
- norm_out: the norm output tensor. [token_num, hidden_dim]
- quant_out: the quant output tensor. [token_num, hidden_dim]
- scale_out: the scale output tensor. Initialization reference: tests/comm/test_trtllm_allreduce_fusion.py
- rms_gamma: the rms gamma tensor. [hidden_dim]
- rms_eps: the rms epsilon value.
- scale_factor: the scale factor. For CUDA graph safety, it should be a tensor.
- layout_code: the layout code.
- metadata: optional workspace metadata dict from create_ipc_workspace_for_all_reduce_fusion. If provided, validates that token_num <= max_token_num, world_size == tp_size, and hidden_dim == workspace hidden_dim. Raises ValueError if validation fails.
- block_quant_group_size: group size (in elements along hidden_dim) for per-token-group block-wise FP8 quantization patterns (e.g. |
|
Parameters: - inp: the input tensor. |
|
Parameters:
- world_size: the size of the process group.
- world_rank: the rank of the current process.
- token_num: the number of tokens in the sequence.
- hidden_dim: the dimension of the hidden states.
- workspace_ptrs: the workspace pointers.
- launch_with_pdl: whether to launch with PDL (programmatic dependent launch).
- residual_in: the residual input tensor. [token_num, hidden_dim]
- rms_gamma: the rms gamma tensor. [hidden_dim]
- rms_eps: the rms epsilon value.
- scale_factor: the scale factor.
- moe_reduction_device_num_experts: the number of experts.
- moe_reduction_scale_input: the scale input tensor. [token_num, hidden_dim]
- moe_reduction_active_experts_token_input: the active experts token input tensor. [token_num, hidden_dim]
- moe_reduction_token_input: the token input tensor. [token_num, hidden_dim]
- layout_code: the layout code.
- moe_allreduce_out: the moe allreduce output tensor. [token_num, hidden_dim]
- residual_out: the residual output tensor. [token_num, hidden_dim]
- norm_out: the norm output tensor. [token_num, hidden_dim]
- quant_out: the quant output tensor. [token_num // 4, hidden_dim], fp16/bf16 -> fp4
- scale_out: the scale output tensor. Initialization reference: tests/comm/test_trtllm_moe_allreduce_fusion.py
- weight_bias: bias added to rms_gamma before scaling. None or 0.0 -> standard RMSNorm (out = gamma * x * rsqrt(...)). 1.0 -> Gemma / Qwen3.5 RMSNorm (out = (1 + gamma) * x * rsqrt(...)). |
|
Parameters:
- allreduce_in: the input tensor. [token_num, top_k, hidden_dim]
- residual_in: the residual input tensor. [token_num, hidden_dim]
- norm_weight: the norm weight tensor. [hidden_dim]
- expanded_idx_to_permuted_idx: the expanded index to permuted index tensor. [token_num, top_k]
- norm_out: the norm output tensor. [token_num, hidden_dim]
- residual_out: the residual output tensor. [token_num, hidden_dim]
- quant_out: the quant output tensor. [token_num // 4, hidden_dim], fp16/bf16 -> fp4
- scale_out: the scale output tensor. [token_num // SF_VEC_SIZE, hidden_dim], fp16/bf16 -> fp4
- workspace_ptrs: the workspace pointers.
- launch_with_pdl: whether to launch with pdl.
- world_rank: the rank of the current process.
- world_size: the size of the process group.
- eps: the epsilon value.
- shared_expert_output: the shared expert output tensor. [token_num, hidden_dim]
- expert_scale_factor: the expert scale factor tensor. [token_num, top_k]
- routed_scaling_factor: the routed scaling factor.
- weight_bias: bias added to rms_gamma before scaling. None or 0.0 -> standard RMSNorm (out = gamma * x * rsqrt(...)). 1.0 -> Gemma / Qwen3.5 RMSNorm (out = (1 + gamma) * x * rsqrt(...)). |
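The effect of weight_bias described in the parameter lists above can be sketched with a small reference implementation. This is an illustrative CPU-side model, not the fused kernel: it assumes the elided rsqrt(...) argument is the usual mean(x**2) + eps, and the function name rms_norm_reference is hypothetical.

```python
import math


def rms_norm_reference(x, gamma, eps=1e-6, weight_bias=None):
    """Reference RMSNorm over one hidden vector given as plain Python lists.

    Assumption: the elided rsqrt(...) argument is mean(x**2) + eps.
    """
    bias = 0.0 if weight_bias is None else weight_bias  # None behaves like 0.0
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # out = (weight_bias + gamma) * x * rsqrt(mean(x^2) + eps)
    return [(bias + g) * v * inv_rms for g, v in zip(gamma, x)]


x = [3.0, 4.0]
gamma = [1.0, 0.5]
standard = rms_norm_reference(x, gamma, eps=0.0)                   # gamma * x * rsqrt(...)
gemma_style = rms_norm_reference(x, gamma, eps=0.0, weight_bias=1.0)  # (1 + gamma) * x * rsqrt(...)
```

With gamma = [1.0, 0.5], the biased variant scales the first element by 2.0 instead of 1.0 and the second by 1.5 instead of 0.5, which is the only difference between the two modes in this model.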
Workspace Management
Parameters: - rank: the rank of the current process. |
|
Parameters: - tp_rank: the rank of the current process. |
|
Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce. |
|
Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce_fusion. |
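The fused all-reduce entry point under Core Operations describes an optional metadata check against the workspace it was created with. A minimal sketch of such a check might look as follows; the dictionary keys max_token_num, tp_size, and hidden_dim are assumptions taken from the wording of that description, and validate_workspace_metadata is a hypothetical helper rather than a flashinfer.comm function.

```python
def validate_workspace_metadata(metadata, token_num, world_size, hidden_dim):
    """Raise ValueError if the call does not fit the workspace it targets.

    The metadata keys used here are assumed names, not a documented schema.
    """
    if metadata is None:
        return  # validation is optional

    if token_num > metadata["max_token_num"]:
        raise ValueError(
            f"token_num={token_num} exceeds max_token_num={metadata['max_token_num']}"
        )
    if world_size != metadata["tp_size"]:
        raise ValueError(
            f"world_size={world_size} does not match tp_size={metadata['tp_size']}"
        )
    if hidden_dim != metadata["hidden_dim"]:
        raise ValueError(
            f"hidden_dim={hidden_dim} does not match workspace hidden_dim={metadata['hidden_dim']}"
        )


meta = {"max_token_num": 2048, "tp_size": 4, "hidden_dim": 4096}
validate_workspace_metadata(meta, token_num=128, world_size=4, hidden_dim=4096)  # passes silently
```

Checking these three values before launch turns a potential out-of-bounds write into the workspace into an ordinary Python exception.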
Initialization and Utilities
|
Initialize a single Lamport-style buffer to negative zero. |
|
Initialize three Lamport buffers to negative zero. |
Compute the padded size (rows times columns) of the FP4 swizzled layout. |
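The padded-size computation can be sketched as rounding each dimension up to a tile boundary and multiplying. The tile sizes used below (128 rows by 4 columns) are an assumption about the swizzled scale-factor layout, and padded_swizzled_sf_size is a hypothetical name for illustration only.

```python
def round_up(value, multiple):
    """Round value up to the nearest multiple."""
    return (value + multiple - 1) // multiple * multiple


def padded_swizzled_sf_size(total_row, total_column, row_tile=128, col_tile=4):
    """Padded rows times padded columns, assuming 128 x 4 tiles."""
    return round_up(total_row, row_tile) * round_up(total_column, col_tile)


print(padded_swizzled_sf_size(1, 1))    # 128 * 4 = 512
print(padded_swizzled_sf_size(130, 6))  # 256 * 8 = 2048
```

Under this assumption even a single-element scale tensor occupies a full tile, so buffers should be sized from the padded value rather than from rows times columns.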
Unified AllReduce Fusion API
|
AllReduce + RMSNorm fusion operation, with optional FP8/NVFP4 quantization for supported backends. |
|
Create workspace for AllReduce fusion operations. |
|
Base class for AllReduce fusion workspaces. |
|
TensorRT-LLM workspace for AllReduce fusion. |
|
vLLM AllReduce
|
Perform an out-of-place all-reduce via the vLLM custom kernel. |
|
Release the resources held by a vLLM custom all-reduce handle. |
|
Initialize the vLLM custom all-reduce backend. |
|
Register a peer's IPC-shared buffer with the local all-reduce handle. |
|
Register graph-capture buffers across the all-reduce world. |
Return IPC metadata for graph-capture buffers. |
|
Return the size of the vLLM all-reduce metadata structure in bytes. |
MNNVL (Multi-Node NVLink)
Core Classes
|
|
|
Wrapper class for SymmDeviceMemory to facilitate PyTorch tensor creation. |
TensorRT-LLM MNNVL AllReduce
|
Deprecated pointer-based MNNVL all-reduce API. |
|
Perform an MNNVL all-reduce sum across tensor-parallel ranks. |
Performs MNNVL Allreduce + Residual + RMSNorm. |
|
Perform MNNVL AllReduce + Residual + RMSNorm + FP8/NVFP4 quantization. |
|
Performs MNNVL TwoShot Allreduce + RMSNorm. |
|
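For readers unfamiliar with the term, a two-shot all-reduce is commonly understood as a reduce-scatter followed by an all-gather. The single-process simulation below sketches that decomposition on plain Python lists; it assumes this is what "TwoShot" refers to here and does not model NVLink multicast, synchronization, or the fused RMSNorm step. The function name two_shot_all_reduce is hypothetical.

```python
def two_shot_all_reduce(rank_inputs):
    """Simulate a two-shot all-reduce (sum) over equal-length per-rank lists."""
    world_size = len(rank_inputs)
    length = len(rank_inputs[0])
    if length % world_size != 0:
        raise ValueError("length must be divisible by world_size in this sketch")
    chunk = length // world_size

    # Shot 1 (reduce-scatter): rank r sums only its own chunk across all ranks.
    reduced_chunks = []
    for r in range(world_size):
        lo, hi = r * chunk, (r + 1) * chunk
        reduced_chunks.append(
            [sum(inp[i] for inp in rank_inputs) for i in range(lo, hi)]
        )

    # Shot 2 (all-gather): every rank collects all reduced chunks.
    full = [value for part in reduced_chunks for value in part]
    return [list(full) for _ in range(world_size)]


inputs = [[1, 2, 3, 4], [10, 20, 30, 40]]
outputs = two_shot_all_reduce(inputs)
print(outputs[0])  # [11, 22, 33, 44]
```

Compared with a one-shot scheme in which every rank reads every peer's full buffer, this decomposition has each rank reduce only 1/world_size of the data, which tends to pay off for larger messages.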
MNNVL A2A (Throughput Backend)
|
Initialize the MoE all-to-all workspace and return a metainfo tensor. |
|
Dispatch tokens and payloads to their target expert ranks. |
|
Combine per-expert outputs back to the originating ranks. |
|
Replace expert IDs not owned by this rank with |
|
Compute the per-rank workspace size for the MoE all-to-all primitive. |
Wrap a slice of the shared workspace as a typed tensor view. |
- class flashinfer.comm.MoeAlltoAll(mapping: Mapping, max_num_tokens: int, top_k: int, num_experts: int, workspace_size_per_rank: int = None, hidden_size: int = None, mnnvl_config: MnnvlConfig | None = None)
Bases: object
Manages MoE All-to-All operations with proper workspace allocation and synchronization.
This class provides the throughput-optimized backend that supports multiple payloads per collective operation, explicit dispatch/combine phases, and workspace-backed tensors.
Example
>>> moe_a2a = MoeAlltoAll(mapping, max_num_tokens=2048, top_k=2, num_experts=8)
>>> recv = moe_a2a.dispatch(experts, [hidden, ids, scales], batch_size)
>>> output = moe_a2a.combine(processed, batch_size)
- __init__(mapping: Mapping, max_num_tokens: int, top_k: int, num_experts: int, workspace_size_per_rank: int = None, hidden_size: int = None, mnnvl_config: MnnvlConfig | None = None)
Initialize MoeAlltoAll and allocate the shared workspace.
- Parameters:
mapping (Mapping) – Mapping object describing the parallel layout (must expose moe_ep_rank and moe_ep_size).
max_num_tokens (int) – Maximum number of tokens this rank will dispatch in any single call.
top_k (int) – Number of experts assigned per token.
num_experts (int) – Total number of experts (across all ranks).
workspace_size_per_rank (int, optional) – Pre-computed workspace size in bytes per rank. When None, hidden_size must be provided and the workspace is sized via get_moe_workspace_size_per_rank().
hidden_size (int, optional) – Hidden dimension size, used to derive workspace_size_per_rank when the latter is omitted.
mnnvl_config (MnnvlConfig, optional) – Optional configuration for the underlying MNNVL communication backend.
- combine(payload: Tensor, runtime_max_tokens_per_rank: int, payload_in_workspace: bool = False) → Tensor
Run the MoE all-to-all combine phase.
- Parameters:
payload (torch.Tensor) – [ep_size, runtime_max_tokens_per_rank, elements_per_token] output payload to scatter back to source ranks.
runtime_max_tokens_per_rank (int) – Maximum tokens per rank in this batch (same value passed to dispatch()).
payload_in_workspace (bool) – True if payload is already a workspace-backed view (skips the staging copy). Defaults to False.
- Returns:
[local_num_tokens, elements_per_token] combined tensor.
- Return type:
torch.Tensor
- dispatch(token_selected_experts: Tensor, input_payloads: list[Tensor], runtime_max_tokens_per_rank: int, invalid_token_expert_id: int | None = None, expert_id_payload_index: int | None = None) → list[Tensor]
Run the MoE all-to-all dispatch phase.
- Parameters:
token_selected_experts (torch.Tensor) – [local_num_tokens, top_k] int32 tensor of expert assignments.
input_payloads (list[torch.Tensor]) – Per-token payload tensors, each shaped [local_num_tokens, *].
runtime_max_tokens_per_rank (int) – Maximum tokens per rank in this batch. Must be <= max_num_tokens used at construction.
invalid_token_expert_id (int, optional) – If supplied, expert IDs not owned by the current rank are rewritten to this value. Requires expert_id_payload_index.
expert_id_payload_index (int, optional) – Index into input_payloads that holds the expert IDs to sanitize. Required when invalid_token_expert_id is set.
- Returns:
Workspace-backed receive tensors, one per input_payloads entry, each shaped [ep_size, runtime_max_tokens_per_rank, *].
- Return type:
list[torch.Tensor]
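The expert-ID rewriting controlled by invalid_token_expert_id can be illustrated on plain Python lists. The sketch below assumes experts are partitioned contiguously and evenly across expert-parallel ranks (rank r owns experts r * num_experts / ep_size up to the next boundary); that ownership rule and the helper name sanitize_expert_ids are assumptions for illustration, not the library's implementation.

```python
def sanitize_expert_ids(expert_ids, ep_rank, ep_size, num_experts, invalid_id):
    """Replace expert IDs not owned by ep_rank with invalid_id.

    Assumes a contiguous, even split of experts across ranks.
    """
    if num_experts % ep_size != 0:
        raise ValueError("this sketch assumes num_experts is divisible by ep_size")
    experts_per_rank = num_experts // ep_size
    first = ep_rank * experts_per_rank
    last = first + experts_per_rank  # exclusive upper bound

    return [
        [eid if first <= eid < last else invalid_id for eid in row]
        for row in expert_ids
    ]


# Two received tokens with top_k = 2, viewed from rank 1 of 2 with 8 experts.
received = [[0, 5], [3, 7]]
local_only = sanitize_expert_ids(received, ep_rank=1, ep_size=2, num_experts=8, invalid_id=-1)
print(local_only)  # [[-1, 5], [-1, 7]]
```

After this step, a local expert kernel can skip every slot marked with the invalid ID instead of re-deriving ownership for each token.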
- get_combine_payload_tensor_in_workspace(runtime_max_tokens_per_rank: int, hidden_size: int, dtype: dtype) → Tensor
Return a workspace-backed view to use as the combine payload.
Zero-copy variant of combine(): experts can write directly into the returned tensor and then call combine() with payload_in_workspace=True. Must be called after a successful dispatch() and before combine().
- Parameters:
runtime_max_tokens_per_rank (int) – Maximum tokens per rank in this batch.
hidden_size (int) – Hidden dimension size.
dtype (torch.dtype) – Element dtype of the resulting view.
- Returns:
[ep_size, runtime_max_tokens_per_rank, hidden_size] workspace-backed tensor.
- Return type:
torch.Tensor
- Raises:
RuntimeError – If called before a successful dispatch().
- static get_moe_workspace_size_per_rank(ep_size: int, top_k: int, max_num_tokens: int, hidden_size: int, extra_payload_bytes_per_token: int = 0) → int
Compute the per-rank workspace size for the MoE all-to-all primitive.
Convenience wrapper around moe_a2a_get_workspace_size_per_rank() that derives the dispatch / combine payload sizes from hidden_size and top_k assuming 16-bit hidden states. For a tighter bound on quantized models, use moe_a2a_get_workspace_size_per_rank() directly.
- Parameters:
ep_size (int) – Total expert-parallel world size.
top_k (int) – Number of experts assigned per token.
max_num_tokens (int) – Maximum number of tokens across all ranks.
hidden_size (int) – Hidden dimension size.
extra_payload_bytes_per_token (int) – Extra payload bytes per token to reserve (e.g. for quantization scales). Defaults to 0.
- Returns:
Required workspace size per rank, in bytes.
- Return type:
int
DCP All-to-All (Context-Parallel Attention Reduction)
|
Return the workspace size (in bytes) per rank for the given CP group size. |
Allocate an MNNVL-backed workspace of shape |
|
|
Initialize the workspace FIFO buffers (call once before the first alltoall). |
|
Perform the DCP all-to-all exchange. |
Mixed Communication
|
Enumeration of mixed communication operation types. |
|
Enumeration of mixed communication execution modes. |
|
An implementation for the combinations of all-reduce + all-gather and reduce-scatter + all-reduce. |
|
Execute a mixed communication operation. |