flashinfer.comm.MoeAlltoAll¶
- class flashinfer.comm.MoeAlltoAll(mapping: Mapping, max_num_tokens: int, top_k: int, num_experts: int, workspace_size_per_rank: int = None, hidden_size: int = None, mnnvl_config: MnnvlConfig | None = None)¶
Manages MoE All-to-All operations with proper workspace allocation and synchronization.
This class provides the throughput-optimized backend that supports multiple payloads per collective operation, explicit dispatch/combine phases, and workspace-backed tensors.
Example
>>> moe_a2a = MoeAlltoAll(mapping, max_num_tokens=2048, top_k=2, num_experts=8)
>>> recv = moe_a2a.dispatch(experts, [hidden, ids, scales], batch_size)
>>> output = moe_a2a.combine(processed, batch_size)
- __init__(mapping: Mapping, max_num_tokens: int, top_k: int, num_experts: int, workspace_size_per_rank: int = None, hidden_size: int = None, mnnvl_config: MnnvlConfig | None = None)¶
Initialize MoeAlltoAll with workspace allocation.
- Parameters:
mapping – Mapping object containing rank information
max_num_tokens – Maximum number of tokens supported
top_k – Number of experts per token
num_experts – Total number of experts
workspace_size_per_rank – Size of the workspace per rank in bytes. If None, hidden_size must be provided (see the construction sketch below).
hidden_size – Hidden dimension size used to calculate the workspace size when workspace_size_per_rank is not provided.
mnnvl_config – Used to configure the communication backend for the MNNVL memory object
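A minimal construction sketch, assuming a flashinfer Mapping has already been set up for the current expert-parallel group. Passing hidden_size lets the class derive the per-rank workspace size; alternatively, an explicit workspace_size_per_rank in bytes can be supplied instead. The helper name build_moe_a2a and the parameter values are illustrative, not part of the API.

>>> from flashinfer.comm import MoeAlltoAll
>>>
>>> def build_moe_a2a(mapping, hidden_size: int = 4096) -> MoeAlltoAll:
...     # Sketch: let MoeAlltoAll size the workspace from hidden_size
...     # (assumption; workspace_size_per_rank may be passed directly instead).
...     return MoeAlltoAll(
...         mapping,                  # Mapping with rank / EP-group information
...         max_num_tokens=2048,      # upper bound on tokens per rank
...         top_k=2,                  # experts selected per token
...         num_experts=8,            # total experts across the EP group
...         hidden_size=hidden_size,  # used to compute the per-rank workspace size
...     )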
Methods
__init__(mapping, max_num_tokens, top_k, ...) – Initialize MoeAlltoAll with workspace allocation.
combine(payload, runtime_max_tokens_per_rank) – Perform MoE all-to-all combine operation.
dispatch(token_selected_experts, ...[, ...]) – Perform MoE all-to-all dispatch operation.
get_combine_payload_tensor_in_workspace(...) – Get combine payload tensor backed by workspace (zero-copy).
get_moe_workspace_size_per_rank(ep_size, ...) – Convenience wrapper to calculate the workspace size per rank for the MoeAlltoAll operation.
get_workspace(workspace_size_per_rank, ...)
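A rough end-to-end sketch of the explicit dispatch/combine flow, built on the class example above. The names expert_fn, hidden, ids, and scales are illustrative placeholders, and the argument order follows the example rather than a verified signature; consult the method references for exact parameters.

>>> def moe_layer(moe_a2a, hidden, ids, scales, experts, batch_size, expert_fn):
...     # Dispatch: route each token's payloads (activations, ids, scales)
...     # to the ranks owning its selected experts.
...     recv = moe_a2a.dispatch(experts, [hidden, ids, scales], batch_size)
...     # Local expert computation on the received payloads
...     # (expert_fn is a placeholder for the per-rank experts).
...     processed = expert_fn(recv)
...     # Combine: return expert outputs to the originating ranks,
...     # reducing the top_k contributions per token.
...     return moe_a2a.combine(processed, batch_size)

For a zero-copy path, get_combine_payload_tensor_in_workspace can provide a combine payload tensor backed by the workspace, so that expert outputs are written directly into the combine buffer rather than copied into it.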