flashinfer.comm.MoeAlltoAll

class flashinfer.comm.MoeAlltoAll(mapping: Mapping, max_num_tokens: int, top_k: int, num_experts: int, workspace_size_per_rank: int = None, hidden_size: int = None, mnnvl_config: MnnvlConfig | None = None)

Manages MoE All-to-All operations with proper workspace allocation and synchronization.

This class provides the throughput-optimized backend that supports multiple payloads per collective operation, explicit dispatch/combine phases, and workspace-backed tensors.
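The two explicit phases can be pictured with a small single-process model. The sketch below is only a conceptual illustration, not the flashinfer implementation: `owner_rank`, `dispatch`, and `combine` are hypothetical helpers, and it assumes experts are split into contiguous equal-sized blocks per rank, that a token is sent once to each rank hosting one of its selected experts, and that combine sums the returned contributions.

```python
# Conceptual single-process model of MoE dispatch/combine.
# All helpers here are hypothetical; this does not call flashinfer.

def owner_rank(expert_id, num_experts, ep_size):
    """Rank hosting expert_id, assuming contiguous equal-sized expert blocks."""
    experts_per_rank = num_experts // ep_size
    return expert_id // experts_per_rank


def dispatch(tokens, selected_experts, num_experts, ep_size):
    """Route each token once to every rank hosting one of its selected experts.

    Returns one inbox per rank holding (source_index, token) pairs.
    """
    inboxes = [[] for _ in range(ep_size)]
    for src, (token, experts) in enumerate(zip(tokens, selected_experts)):
        target_ranks = sorted({owner_rank(e, num_experts, ep_size) for e in experts})
        for rank in target_ranks:
            inboxes[rank].append((src, token))
    return inboxes


def combine(processed_inboxes, num_tokens):
    """Return results to their source slot and sum them (assumed reduction)."""
    output = [0.0] * num_tokens
    for inbox in processed_inboxes:
        for src, value in inbox:
            output[src] += value
    return output


tokens = [1.0, 2.0, 3.0]
selected_experts = [[0, 5], [2, 3], [6, 7]]  # top_k = 2, num_experts = 8
inboxes = dispatch(tokens, selected_experts, num_experts=8, ep_size=2)

# Stand-in for expert computation: each rank scales its tokens by (rank + 1).
processed = [
    [(src, tok * (rank + 1)) for src, tok in inbox]
    for rank, inbox in enumerate(inboxes)
]
output = combine(processed, num_tokens=len(tokens))
```

In this toy run, token 0 selects experts on both ranks, so it appears in both inboxes and receives two contributions during combine.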

Example

>>> moe_a2a = MoeAlltoAll(mapping, max_num_tokens=2048, top_k=2, num_experts=8)
>>> recv = moe_a2a.dispatch(experts, [hidden, ids, scales], batch_size)
>>> output = moe_a2a.combine(processed, batch_size)

__init__(mapping: Mapping, max_num_tokens: int, top_k: int, num_experts: int, workspace_size_per_rank: int = None, hidden_size: int = None, mnnvl_config: MnnvlConfig | None = None)

Initialize MoeAlltoAll with workspace allocation.

Parameters:
  • mapping – Mapping object containing rank information

  • max_num_tokens – Maximum number of tokens supported

  • top_k – Number of experts per token

  • num_experts – Total number of experts

  • workspace_size_per_rank – Size of the workspace per rank, in bytes. If None, hidden_size must be provided

  • hidden_size – Hidden dimension size, used to calculate the workspace size when workspace_size_per_rank is not provided

  • mnnvl_config – Configuration for the communication backend of the MNNVL memory object
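
The relationship between workspace_size_per_rank and hidden_size described above can be sketched as a small resolution rule. This is a minimal illustration under stated assumptions: `resolve_workspace_size` and `toy_estimate` are hypothetical names, and the sizing formula in `toy_estimate` is a placeholder, not the calculation the library performs.

```python
from typing import Callable, Optional


def resolve_workspace_size(
    workspace_size_per_rank: Optional[int],
    hidden_size: Optional[int],
    estimate_from_hidden_size: Callable[[int], int],
) -> int:
    """Return the workspace size in bytes, preferring an explicit value."""
    if workspace_size_per_rank is not None:
        return workspace_size_per_rank
    if hidden_size is None:
        raise ValueError(
            "hidden_size must be provided when workspace_size_per_rank is None"
        )
    return estimate_from_hidden_size(hidden_size)


def toy_estimate(hidden_size: int) -> int:
    """Placeholder estimate (assumed): 2048 tokens * hidden_size * 2 bytes."""
    return 2048 * hidden_size * 2


explicit = resolve_workspace_size(1 << 20, None, toy_estimate)
derived = resolve_workspace_size(None, 4096, toy_estimate)
```

An explicit byte count takes precedence. The hidden size only matters when no explicit size is given, and omitting both is an error.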

Methods

__init__(mapping, max_num_tokens, top_k, ...)

Initialize MoeAlltoAll with workspace allocation.

combine(payload, runtime_max_tokens_per_rank)

Perform MoE all-to-all combine operation.

dispatch(token_selected_experts, ...[, ...])

Perform MoE all-to-all dispatch operation.

get_combine_payload_tensor_in_workspace(...)

Get combine payload tensor backed by workspace (zero-copy).

get_moe_workspace_size_per_rank(ep_size, ...)

Convenience wrapper to calculate the workspace size per rank for the MoeAlltoAll operation.

get_workspace(workspace_size_per_rank, ...)