flashinfer.comm

This module provides communication primitives and utilities for distributed computing, including CUDA IPC, AllReduce operations, and memory management utilities.

CUDA IPC Utilities

CudaRTLibrary([so_file])

create_shared_buffer(size_in_bytes[, group])

Creates a shared buffer and returns a list of pointers representing the buffer on all processes in the group.

free_shared_buffer(pointers[, group])

Frees a shared buffer.
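Because create_shared_buffer returns one pointer per process and free_shared_buffer releases them, it can help to pair the two calls so the buffer is released even when an error occurs in between. The following is a minimal sketch of that pairing. The wrapper name shared_buffer and the fake allocator are hypothetical, introduced only so the example runs without CUDA. In real use you would presumably pass flashinfer.comm.create_shared_buffer and flashinfer.comm.free_shared_buffer after the process group has been initialized.

```python
from contextlib import contextmanager


@contextmanager
def shared_buffer(create, free, size_in_bytes, group=None):
    """Yield the per-process pointer list and always free it afterwards.

    `create` and `free` are expected to behave like
    create_shared_buffer(size_in_bytes, group) and
    free_shared_buffer(pointers, group). This wrapper is hypothetical.
    """
    pointers = create(size_in_bytes, group=group)
    try:
        yield pointers
    finally:
        free(pointers, group=group)


# Stand-ins so the sketch runs without a GPU: a fake 4-process allocator.
freed = []


def fake_create(size_in_bytes, group=None):
    return [0x1000 + rank * size_in_bytes for rank in range(4)]


def fake_free(pointers, group=None):
    freed.append(list(pointers))


with shared_buffer(fake_create, fake_free, 256) as ptrs:
    num_pointers = len(ptrs)  # one pointer per process in the group
```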

DLPack Utilities

pack_strided_memory(ptr, segment_size, ...)

Pack GPU memory into a PyTorch tensor with specified stride.

Mapping Utilities

Mapping([world_size, rank, gpus_per_node, ...])

Example: a node with 8 GPUs, tp_size = 4, cp_size = 1, pp_size = 2.
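To make the 8-GPU configuration concrete, the sketch below splits a global rank into pipeline, context, and tensor parallel coordinates. It assumes that tensor parallelism is the fastest-varying dimension, followed by context parallelism and then pipeline parallelism. That ordering is an assumption for illustration, not something stated here, and the helper names are hypothetical rather than methods of Mapping.

```python
def decompose_rank(rank, tp_size, cp_size, pp_size):
    """Split a global rank into (pp_rank, cp_rank, tp_rank).

    Assumes TP varies fastest, then CP, then PP (an assumed layout).
    """
    world_size = tp_size * cp_size * pp_size
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside a world of size {world_size}")
    tp_rank = rank % tp_size
    cp_rank = (rank // tp_size) % cp_size
    pp_rank = rank // (tp_size * cp_size)
    return pp_rank, cp_rank, tp_rank


def tp_group(rank, tp_size):
    """Ranks in the same tensor-parallel group (a contiguous block)."""
    start = (rank // tp_size) * tp_size
    return list(range(start, start + tp_size))


def pp_group(rank, tp_size, cp_size, pp_size):
    """Ranks in the same pipeline-parallel group (strided by tp_size * cp_size)."""
    stride = tp_size * cp_size
    return [rank % stride + stage * stride for stage in range(pp_size)]


# The configuration named in the text: 8 GPUs, tp_size=4, cp_size=1, pp_size=2.
layout = [decompose_rank(r, tp_size=4, cp_size=1, pp_size=2) for r in range(8)]
```

Under this assumed layout, rank 5 sits in pipeline stage 1 with tensor-parallel rank 1, sharing a TP group with ranks 4 to 7 and a PP group with rank 1.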

TensorRT-LLM AllReduce

Types and Enums

Core Operations

trtllm_allreduce_fusion(allreduce_in, ...[, ...])

Parameters:
- allreduce_in: the input tensor. [token_num, hidden_dim]
- world_size: the size of the process group.
- world_rank: the rank of the current process.
- token_num: the number of tokens in the sequence.
- hidden_dim: the dimension of the hidden states.
- workspace_ptrs: the workspace pointers.
- launch_with_pdl: whether to launch with PDL.
- use_oneshot: whether to use oneshot. If None, internal heuristics will be used.
- trigger_completion_at_end: whether to trigger completion at the end.
- fp32_acc: whether to use fp32 accumulation.
- pattern_code: the pattern code.
- allreduce_out: the output tensor. [token_num, hidden_dim]
- residual_in: the residual input tensor. [token_num, hidden_dim]
- residual_out: the residual output tensor. [token_num, hidden_dim]
- norm_out: the norm output tensor. [token_num, hidden_dim]
- quant_out: the quant output tensor. [token_num, hidden_dim]
- scale_out: the scale output tensor. Initialization reference: tests/comm/test_trtllm_allreduce_fusion.py
- rms_gamma: the rms gamma tensor. [hidden_dim]
- rms_eps: the rms epsilon value.
- scale_factor: the scale factor. For CUDA Graphs safety, it should be a tensor.
- layout_code: the layout code.
- metadata: optional workspace metadata dict from create_ipc_workspace_for_all_reduce_fusion. If provided, validates that token_num <= max_token_num, world_size == tp_size, and hidden_dim == workspace hidden_dim. Raises ValueError if validation fails.
- weight_bias: bias added to rms_gamma before scaling. None or 0.0 -> standard RMSNorm (out = gamma * x * rsqrt(...)). 1.0 -> Gemma / Qwen3.5 RMSNorm (out = (1 + gamma) * x * rsqrt(...)). Ignored for kAllReduce and quant-only patterns that don't apply RMSNorm.
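The three metadata checks described for trtllm_allreduce_fusion can be sketched as a small standalone function. This is an illustration of the stated rules only. The function name validate_fusion_call and the dictionary keys max_token_num, tp_size, and hidden_dim are assumptions, and the real metadata dict may use different keys.

```python
def validate_fusion_call(metadata, token_num, world_size, hidden_dim):
    """Raise ValueError if a call does not fit the workspace it was given.

    Key names in `metadata` are assumed for illustration.
    """
    if metadata is None:
        return  # no metadata supplied, so nothing to validate
    if token_num > metadata["max_token_num"]:
        raise ValueError(
            f"token_num {token_num} exceeds max_token_num "
            f"{metadata['max_token_num']}"
        )
    if world_size != metadata["tp_size"]:
        raise ValueError(
            f"world_size {world_size} != workspace tp_size {metadata['tp_size']}"
        )
    if hidden_dim != metadata["hidden_dim"]:
        raise ValueError(
            f"hidden_dim {hidden_dim} != workspace hidden_dim "
            f"{metadata['hidden_dim']}"
        )


# A call that fits the (hypothetical) workspace passes silently.
example_metadata = {"max_token_num": 2048, "tp_size": 4, "hidden_dim": 4096}
validate_fusion_call(example_metadata, token_num=128, world_size=4, hidden_dim=4096)
```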

trtllm_custom_all_reduce(inp, out, tp_size, ...)

Parameters: - inp: the input tensor.

trtllm_moe_allreduce_fusion(world_size, ...)

Parameters:
- world_size: the size of the process group.
- world_rank: the rank of the current process.
- token_num: the number of tokens in the sequence.
- hidden_dim: the dimension of the hidden states.
- workspace_ptrs: the workspace pointers.
- launch_with_pdl: whether to launch with PDL.
- residual_in: the residual input tensor. [token_num, hidden_dim]
- rms_gamma: the rms gamma tensor. [hidden_dim]
- rms_eps: the rms epsilon value.
- scale_factor: the scale factor.
- moe_reduction_device_num_experts: the number of experts.
- moe_reduction_scale_input: the scale input tensor. [token_num, hidden_dim]
- moe_reduction_active_experts_token_input: the active experts token input tensor. [token_num, hidden_dim]
- moe_reduction_token_input: the token input tensor. [token_num, hidden_dim]
- layout_code: the layout code.
- moe_allreduce_out: the moe allreduce output tensor. [token_num, hidden_dim]
- residual_out: the residual output tensor. [token_num, hidden_dim]
- norm_out: the norm output tensor. [token_num, hidden_dim]
- quant_out: the quant output tensor. [token_num // 4, hidden_dim], fp16/bf16 -> fp4
- scale_out: the scale output tensor. Initialization reference: tests/comm/test_trtllm_moe_allreduce_fusion.py
- weight_bias: bias added to rms_gamma before scaling. None or 0.0 -> standard RMSNorm (out = gamma * x * rsqrt(...)). 1.0 -> Gemma / Qwen3.5 RMSNorm (out = (1 + gamma) * x * rsqrt(...)).
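The effect of weight_bias can be checked with a small CPU reference. The sketch below assumes the elided rsqrt argument is the usual RMSNorm term mean(x^2) + eps, which is not spelled out in the text. The function rms_norm is a hypothetical helper, not part of flashinfer.comm.

```python
import math


def rms_norm(x, gamma, eps, weight_bias=None):
    """RMSNorm over one hidden vector, with an optional bias added to gamma.

    Assumes the normalizer is rsqrt(mean(x^2) + eps).
    """
    bias = 0.0 if weight_bias is None else weight_bias
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [(bias + g) * v * inv_rms for g, v in zip(gamma, x)]


x = [3.0, 4.0]
# Standard RMSNorm with gamma = 1.
standard = rms_norm(x, gamma=[1.0, 1.0], eps=0.0)
# Gemma-style: gamma is stored as an offset from 1, so gamma = 0 with
# weight_bias = 1.0 gives the same scaling.
gemma_style = rms_norm(x, gamma=[0.0, 0.0], eps=0.0, weight_bias=1.0)
```

The two results coincide, which shows that weight_bias = 1.0 only changes how gamma is stored, not the normalization itself.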

trtllm_moe_finalize_allreduce_fusion(...[, ...])

Parameters:
- allreduce_in: the input tensor. [token_num, top_k, hidden_dim]
- residual_in: the residual input tensor. [token_num, hidden_dim]
- norm_weight: the norm weight tensor. [hidden_dim]
- expanded_idx_to_permuted_idx: the expanded index to permuted index tensor. [token_num, top_k]
- norm_out: the norm output tensor. [token_num, hidden_dim]
- residual_out: the residual output tensor. [token_num, hidden_dim]
- quant_out: the quant output tensor. [token_num // 4, hidden_dim], fp16/bf16 -> fp4
- scale_out: the scale output tensor. [token_num // SF_VEC_SIZE, hidden_dim], fp16/bf16 -> fp4
- workspace_ptrs: the workspace pointers.
- launch_with_pdl: whether to launch with PDL.
- world_rank: the rank of the current process.
- world_size: the size of the process group.
- eps: the epsilon value.
- shared_expert_output: the shared expert output tensor. [token_num, hidden_dim]
- expert_scale_factor: the expert scale factor tensor. [token_num, top_k]
- routed_scaling_factor: the routed scaling factor.
- weight_bias: bias added to rms_gamma before scaling. None or 0.0 -> standard RMSNorm (out = gamma * x * rsqrt(...)). 1.0 -> Gemma / Qwen3.5 RMSNorm (out = (1 + gamma) * x * rsqrt(...)).
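The parameter shapes of trtllm_moe_finalize_allreduce_fusion suggest a per-token weighted combination of top_k expert rows followed by a residual add and RMSNorm. The sketch below is one plausible single-rank reading of those shapes and is not confirmed by the text. It assumes expanded_idx_to_permuted_idx selects rows from a flat list of expert outputs, it omits the cross-rank sum and routed_scaling_factor, and the order of operations is a guess. The function name is hypothetical.

```python
import math


def moe_finalize_reference(expert_rows, expanded_idx_to_permuted_idx,
                           expert_scale_factor, residual_in, norm_weight,
                           eps, shared_expert_output=None):
    """Single-rank reference; the cross-rank reduction is omitted.

    expert_rows: flat list of [hidden_dim] rows in permuted order (assumed).
    Returns (residual_out, norm_out), each [token_num][hidden_dim].
    """
    residual_out, norm_out = [], []
    rows = zip(expanded_idx_to_permuted_idx, expert_scale_factor)
    for t, (indices, scales) in enumerate(rows):
        hidden = list(residual_in[t])
        if shared_expert_output is not None:
            hidden = [h + s for h, s in zip(hidden, shared_expert_output[t])]
        # Weighted sum of the top_k expert rows chosen for this token.
        for idx, scale in zip(indices, scales):
            hidden = [h + scale * v for h, v in zip(hidden, expert_rows[idx])]
        inv_rms = 1.0 / math.sqrt(sum(h * h for h in hidden) / len(hidden) + eps)
        residual_out.append(hidden)
        norm_out.append([w * h * inv_rms for w, h in zip(norm_weight, hidden)])
    return residual_out, norm_out


# One token, top_k = 2, hidden_dim = 2.
residual_out, norm_out = moe_finalize_reference(
    expert_rows=[[1.0, 0.0], [0.0, 2.0]],
    expanded_idx_to_permuted_idx=[[1, 0]],
    expert_scale_factor=[[0.5, 1.0]],
    residual_in=[[2.0, 3.0]],
    norm_weight=[1.0, 1.0],
    eps=0.0,
)
```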

Workspace Management

trtllm_create_ipc_workspace_for_all_reduce(...)

Parameters: - rank: the rank of the current process.

trtllm_create_ipc_workspace_for_all_reduce_fusion(...)

Parameters: - tp_rank: the rank of the current process.

trtllm_destroy_ipc_workspace_for_all_reduce(...)

Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce.

trtllm_destroy_ipc_workspace_for_all_reduce_fusion(...)

Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce_fusion.

Initialization and Utilities

trtllm_lamport_initialize(buffer_ptr, size, ...)

trtllm_lamport_initialize_all(buffer_0_ptr, ...)

Initialize three Lamport buffers with negative zero.
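A likely reason for choosing negative zero is that it compares equal to 0.0 numerically but has a distinct bit pattern, so a reader can poll a slot until it no longer holds that pattern. The sketch below illustrates the idea on the CPU with the standard struct module. The helper names are hypothetical, and the polling interpretation is an assumption about how Lamport-style buffers are used, not a description of the kernel.

```python
import struct

# Bit patterns of -0.0: only the sign bit is set.
NEG_ZERO_F32 = struct.pack("<f", -0.0)  # 4 bytes, little-endian
NEG_ZERO_F16 = struct.pack("<e", -0.0)  # 2 bytes, little-endian


def init_buffer(num_floats):
    """Fill an fp32 buffer with the -0.0 sentinel."""
    return bytearray(NEG_ZERO_F32 * num_floats)


def is_unwritten(raw4):
    """True if a 4-byte fp32 slot still holds the -0.0 sentinel."""
    return raw4 == NEG_ZERO_F32


buf = init_buffer(4)
# Simulate a peer writing +0.0 into slot 1. Numerically this equals the
# sentinel, yet the bytes differ, so the write is still detectable.
struct.pack_into("<f", buf, 4, 0.0)
slots = [bytes(buf[i:i + 4]) for i in range(0, len(buf), 4)]
pending = [is_unwritten(s) for s in slots]
```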

compute_fp4_swizzled_layout_sf_size(...)

Helper function to compute the padded size of the fp4 swizzled layout.
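The padding arithmetic can be illustrated as rounding each dimension up to a tile multiple and multiplying. The sketch below assumes rows are padded to a multiple of 128 and columns to a multiple of 4. Those tile sizes are assumptions not stated in the text, and swizzled_sf_size is a hypothetical stand-in rather than the library function.

```python
def pad_up(value, multiple):
    """Round `value` up to the next multiple of `multiple`."""
    return (value + multiple - 1) // multiple * multiple


def swizzled_sf_size(total_row, total_column, row_tile=128, col_tile=4):
    """Padded element count of a scale-factor buffer (assumed tile sizes)."""
    return pad_up(total_row, row_tile) * pad_up(total_column, col_tile)


# 100 rows pad to 128 and 6 columns pad to 8 under the assumed tiles.
size = swizzled_sf_size(100, 6)
```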

Unified AllReduce Fusion API

allreduce_fusion(input, workspace, pattern)

AllReduce + RMSNorm fusion operation.

create_allreduce_fusion_workspace([backend, ...])

Create workspace for AllReduce fusion operations.

AllReduceFusionWorkspace(world_size, rank)

Base class for AllReduce fusion workspaces.

TRTLLMAllReduceFusionWorkspace(tp_size, ...)

TensorRT-LLM workspace for AllReduce fusion.

MNNVLAllReduceFusionWorkspace(mapping[, ...])

vLLM AllReduce

vllm_all_reduce(fa, inp, out, reg_buffer, ...)

Performs an out-of-place all reduce.
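"Out-of-place" here means the result is written to a separate output and the input is left unchanged. A minimal CPU simulation of that contract, with each rank's tensor modeled as a Python list, might look like the following. The function is a hypothetical illustration and does not call vllm_all_reduce.

```python
def all_reduce_out_of_place(per_rank_inputs):
    """Return one summed output per rank; the inputs are not modified."""
    reduced = [sum(values) for values in zip(*per_rank_inputs)]
    # Every rank receives its own copy of the reduced result.
    return [list(reduced) for _ in per_rank_inputs]


inputs = [[1, 2], [10, 20], [100, 200]]  # three simulated ranks
outputs = all_reduce_out_of_place(inputs)
```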

vllm_dispose(fa)

vllm_init_custom_ar(ipc_tensors, rank_data, ...)

vllm_register_buffer(fa, fake_ipc_ptrs)

vllm_register_graph_buffers(fa, handles, offsets)

vllm_get_graph_buffer_ipc_meta(fa)

vllm_meta_size()

TensorRT-LLM MNNVL AllReduce

trtllm_mnnvl_all_reduce(inp, ...[, out])

Perform a multi-node NVLink all-reduce operation across multiple GPUs.

trtllm_mnnvl_allreduce(input, workspace, ...)

Perform a multi-node NVLink all-reduce operation across multiple GPUs.

trtllm_mnnvl_fused_allreduce_add_rmsnorm(...)

Performs MNNVL Allreduce + Residual + RMSNorm.

trtllm_mnnvl_fused_allreduce_rmsnorm(...)

Performs MNNVL TwoShot Allreduce + RMSNorm.
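A two-shot all-reduce is commonly structured as a reduce-scatter followed by an all-gather: each rank first sums one chunk of the data, then every rank collects all reduced chunks. The sketch below simulates that structure on the CPU. Whether the MNNVL kernel follows exactly this scheme is an assumption, and the function name is hypothetical.

```python
def two_shot_all_reduce(per_rank_inputs):
    """Simulate reduce-scatter + all-gather over lists, one list per rank."""
    world_size = len(per_rank_inputs)
    length = len(per_rank_inputs[0])
    if length % world_size != 0:
        raise ValueError("length must be divisible by world_size")
    chunk = length // world_size

    # Shot 1 (reduce-scatter): rank r sums chunk r across all ranks.
    partial = []
    for r in range(world_size):
        lo, hi = r * chunk, (r + 1) * chunk
        partial.append(
            [sum(inp[i] for inp in per_rank_inputs) for i in range(lo, hi)]
        )

    # Shot 2 (all-gather): every rank collects all reduced chunks.
    full = [v for chunk_values in partial for v in chunk_values]
    return [list(full) for _ in range(world_size)]


results = two_shot_all_reduce([[1, 2, 3, 4], [10, 20, 30, 40]])
```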

mpi_barrier()

MNNVL A2A (Throughput Backend)

MoeAlltoAll(mapping, max_num_tokens, top_k, ...)

Manages MoE All-to-All operations with proper workspace allocation and synchronization.

moe_a2a_initialize(workspace, ep_rank, ...)

moe_a2a_dispatch(token_selected_experts, ...)

Dispatch tokens and payloads to expert ranks.
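The routing decision behind a dispatch can be sketched as mapping each token's selected experts to the ranks that host them. The example below assumes experts are partitioned contiguously across expert-parallel ranks and that a token is sent to a rank at most once even when several of its experts live there. Both are assumptions for illustration. plan_dispatch is a hypothetical helper and performs no communication.

```python
def plan_dispatch(token_selected_experts, num_experts, ep_size):
    """Map each EP rank to the token indices it should receive.

    Assumes expert e lives on rank e // (num_experts // ep_size).
    """
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by ep_size")
    experts_per_rank = num_experts // ep_size
    sends = {rank: [] for rank in range(ep_size)}
    for token, experts in enumerate(token_selected_experts):
        # Deduplicate so a token goes to each target rank only once.
        target_ranks = {e // experts_per_rank for e in experts}
        for rank in sorted(target_ranks):
            sends[rank].append(token)
    return sends


# Three tokens, top_k = 2, 8 experts spread over 4 ranks.
sends = plan_dispatch([[0, 1], [3, 6], [7, 2]], num_experts=8, ep_size=4)
```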

moe_a2a_combine(payload, local_num_tokens, ...)

moe_a2a_sanitize_expert_ids(expert_ids, ...)

moe_a2a_get_workspace_size_per_rank(ep_size, ...)

Get the workspace size per rank for the MoeAlltoAll operation.

moe_a2a_wrap_payload_tensor_in_workspace(...)

Wrap an offset in the workspace into a tensor.
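Wrapping an offset into a tensor means creating a typed view over part of an existing allocation without copying it. The CPU analogy below uses a bytearray as the workspace and a memoryview as the wrapped payload. It is only an analogy: the helper wrap_payload is hypothetical and the real function operates on the MoE all-to-all workspace and returns a tensor.

```python
import struct

workspace = bytearray(64)  # stand-in for a preallocated workspace


def wrap_payload(buffer, offset_bytes, num_floats):
    """Return a float32 view over `buffer` starting at `offset_bytes`."""
    nbytes = num_floats * 4
    if offset_bytes < 0 or offset_bytes + nbytes > len(buffer):
        raise ValueError("payload does not fit in the workspace")
    view = memoryview(buffer)[offset_bytes:offset_bytes + nbytes]
    return view.cast("f")  # reinterpret the bytes, no copy is made


payload = wrap_payload(workspace, offset_bytes=16, num_floats=6)
payload[5] = 7.5  # writes through to the underlying workspace

# Read the same location directly from the workspace (native byte order).
raw = struct.unpack_from("=f", workspace, 16 + 5 * 4)[0]
```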