flashinfer.comm.moe_a2a_initialize
- flashinfer.comm.moe_a2a_initialize(workspace: Tensor, ep_rank: int, ep_size: int, max_num_tokens: int)
Initialize the MoE all-to-all workspace and return a metainfo tensor.
The metainfo tensor encodes per-rank offsets and bookkeeping required by
moe_a2a_dispatch() and moe_a2a_combine(); it must be passed back into those routines for the same workspace. moe_a2a_initialize is idempotent and must be called once per workspace allocation before any dispatch/combine.
- Parameters:
workspace (torch.Tensor) – [ep_size, size_per_rank] shared workspace tensor.
ep_rank (int) – Current expert-parallel rank.
ep_size (int) – Total expert-parallel world size.
max_num_tokens (int) – Maximum number of tokens any rank may dispatch in a single call; used to size the metainfo allocation.
- Returns:
Metainfo tensor opaque to callers; pass it to subsequent moe_a2a_* calls.
- Return type:
torch.Tensor
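The parameter descriptions above imply a few consistency conditions that can be checked before calling moe_a2a_initialize. The following is a minimal sketch, not part of FlashInfer: validate_a2a_init_args is a hypothetical helper, and the bounds 0 <= ep_rank < ep_size and max_num_tokens > 0 are assumptions based on the usual rank convention rather than documented requirements. Only the [ep_size, size_per_rank] workspace shape comes from the text above. How the shared workspace itself is allocated is not covered here.

```python
from typing import Sequence


def validate_a2a_init_args(
    workspace_shape: Sequence[int],
    ep_rank: int,
    ep_size: int,
    max_num_tokens: int,
) -> None:
    """Hypothetical pre-flight check; raises ValueError on inconsistent arguments."""
    # Documented: the workspace is a 2-D tensor shaped [ep_size, size_per_rank].
    if len(workspace_shape) != 2:
        raise ValueError("workspace must be 2-D: [ep_size, size_per_rank]")
    if workspace_shape[0] != ep_size:
        raise ValueError(
            f"workspace dim 0 is {workspace_shape[0]}, expected ep_size={ep_size}"
        )
    # Assumption: ranks are numbered 0 .. ep_size - 1.
    if not 0 <= ep_rank < ep_size:
        raise ValueError(f"ep_rank={ep_rank} is outside [0, {ep_size})")
    # Assumption: a non-positive token budget is never meaningful.
    if max_num_tokens <= 0:
        raise ValueError("max_num_tokens must be positive")


# Intended call order on a machine with FlashInfer installed (not executed here):
#   import flashinfer.comm as comm
#   validate_a2a_init_args(tuple(workspace.shape), ep_rank, ep_size, max_num_tokens)
#   metainfo = comm.moe_a2a_initialize(workspace, ep_rank, ep_size, max_num_tokens)
#   # Keep `metainfo` alongside `workspace` and pass it to later moe_a2a_* calls.
```

Running the check on every rank before initialization surfaces a mismatched ep_size or workspace shape as a Python exception instead of a failure inside the later dispatch or combine calls.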