flashinfer.comm.moe_a2a_initialize

flashinfer.comm.moe_a2a_initialize(workspace: Tensor, ep_rank: int, ep_size: int, max_num_tokens: int)

Initialize the MoE all-to-all workspace and return a metainfo tensor.

The metainfo tensor encodes per-rank offsets and bookkeeping required by moe_a2a_dispatch() and moe_a2a_combine(), and it must be passed back into those routines together with the same workspace. moe_a2a_initialize is idempotent; call it once per workspace allocation, before any dispatch or combine call.

Parameters:
  • workspace (torch.Tensor) – Shared workspace tensor of shape [ep_size, size_per_rank].

  • ep_rank (int) – Current expert-parallel rank.

  • ep_size (int) – Total expert-parallel world size.

  • max_num_tokens (int) – Maximum number of tokens any rank may dispatch in a single call; used to size the metainfo allocation.

Returns:

Metainfo tensor, opaque to callers; pass it unchanged to subsequent moe_a2a_* calls that use the same workspace.

Return type:

torch.Tensor
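
Example:

The parameter descriptions above imply a few preconditions that are cheap to check before the call. The sketch below is one possible way to do that, not part of FlashInfer: checked_initialize is a hypothetical helper, and it takes the initializer as an argument so it can be exercised without a GPU. The checks that ep_rank lies in [0, ep_size) and that max_num_tokens is positive are assumptions based on the usual meaning of those parameters; the documentation above does not state them explicitly.

```python
from typing import Any, Callable


def checked_initialize(
    initialize: Callable[..., Any],
    workspace: Any,
    ep_rank: int,
    ep_size: int,
    max_num_tokens: int,
) -> Any:
    """Validate arguments, then forward them to `initialize`.

    `initialize` is expected to be flashinfer.comm.moe_a2a_initialize;
    `workspace` only needs a `.shape` attribute (as torch.Tensor has).
    """
    shape = tuple(workspace.shape)

    # Documented layout: [ep_size, size_per_rank].
    if len(shape) != 2:
        raise ValueError(
            f"workspace must be 2-D [ep_size, size_per_rank], got shape {shape}"
        )
    if shape[0] != ep_size:
        raise ValueError(
            f"workspace.shape[0] is {shape[0]}, expected ep_size={ep_size}"
        )

    # Assumption: ranks are numbered 0 .. ep_size - 1.
    if not 0 <= ep_rank < ep_size:
        raise ValueError(f"ep_rank={ep_rank} is outside [0, {ep_size})")

    # Assumption: a non-positive token budget is never meaningful.
    if max_num_tokens <= 0:
        raise ValueError(f"max_num_tokens must be positive, got {max_num_tokens}")

    # Positional order follows the signature documented above.
    return initialize(workspace, ep_rank, ep_size, max_num_tokens)
```

With FlashInfer installed, a call would look like metainfo = checked_initialize(flashinfer.comm.moe_a2a_initialize, workspace, ep_rank, ep_size, max_num_tokens). Keep the returned metainfo alongside the workspace it was created for, since the two are used together in later moe_a2a_* calls.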