flashinfer.comm.moe_a2a_get_workspace_size_per_rank
- flashinfer.comm.moe_a2a_get_workspace_size_per_rank(ep_size: int, max_num_tokens: int, total_dispatch_payload_size_per_token: int, combine_payload_size_per_token: int)
Compute the workspace size, in bytes, that each rank needs for the MoE all-to-all (dispatch and combine) primitive.
- Parameters:
ep_size (int) – Total expert-parallel world size.
max_num_tokens (int) – Maximum number of tokens across all ranks.
total_dispatch_payload_size_per_token (int) – Sum (in bytes) of all per-token payloads sent during the dispatch phase.
combine_payload_size_per_token (int) – Per-token payload size (in bytes) sent back during the combine phase.
- Returns:
Required workspace size per rank, in bytes.
- Return type:
int
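The sketch below shows one way the two payload arguments might be derived before calling this function. The payload layout (a bfloat16 hidden state plus int32 expert ids and float32 routing weights per token), the sizes `HIDDEN_SIZE = 7168` and `TOP_K = 8`, and the helpers `payload_bytes` and `workspace_bytes_per_rank` are hypothetical choices for illustration, not part of FlashInfer. Only `moe_a2a_get_workspace_size_per_rank` and its keyword arguments come from the signature above.

```python
from typing import Dict, List

# Element widths in bytes for the dtypes used in this hypothetical layout.
DTYPE_BYTES: Dict[str, int] = {"bfloat16": 2, "float32": 4, "int32": 4}

# Hypothetical MoE layer shape, chosen only for illustration.
HIDDEN_SIZE = 7168
TOP_K = 8


def payload_bytes(num_elements: int, dtype: str) -> int:
    """Bytes needed to send `num_elements` values of `dtype` for one token."""
    return num_elements * DTYPE_BYTES[dtype]


# Assumed per-token payloads sent during dispatch (hypothetical layout).
dispatch_payloads: List[int] = [
    payload_bytes(HIDDEN_SIZE, "bfloat16"),  # token hidden state
    payload_bytes(TOP_K, "int32"),           # selected expert ids
    payload_bytes(TOP_K, "float32"),         # routing weights
]
# The dispatch argument is the sum of every per-token payload.
total_dispatch_payload_size_per_token = sum(dispatch_payloads)

# Assumed combine payload: the expert output, same shape as the hidden state.
combine_payload_size_per_token = payload_bytes(HIDDEN_SIZE, "bfloat16")


def workspace_bytes_per_rank(ep_size: int, max_num_tokens: int) -> int:
    """Forward the byte counts above to FlashInfer (requires flashinfer)."""
    # Imported lazily so the payload arithmetic runs without flashinfer.
    from flashinfer.comm import moe_a2a_get_workspace_size_per_rank

    return moe_a2a_get_workspace_size_per_rank(
        ep_size=ep_size,
        max_num_tokens=max_num_tokens,
        total_dispatch_payload_size_per_token=total_dispatch_payload_size_per_token,
        combine_payload_size_per_token=combine_payload_size_per_token,
    )


if __name__ == "__main__":
    print(total_dispatch_payload_size_per_token)  # 14400
    print(combine_payload_size_per_token)  # 14336
```

In this layout the dispatch argument adds up three per-token payloads (14336 + 32 + 32 = 14400 bytes), while the combine argument covers only the single tensor sent back during the combine phase.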