flashinfer.comm.moe_a2a_get_workspace_size_per_rank

flashinfer.comm.moe_a2a_get_workspace_size_per_rank(ep_size: int, max_num_tokens: int, total_dispatch_payload_size_per_token: int, combine_payload_size_per_token: int)

Compute the per-rank workspace size, in bytes, required by the mixture-of-experts (MoE) all-to-all primitive.

Parameters:
  • ep_size (int) – Total expert-parallel world size.

  • max_num_tokens (int) – Maximum number of tokens across all ranks.

  • total_dispatch_payload_size_per_token (int) – Sum (in bytes) of all per-token payloads sent during the dispatch phase.

  • combine_payload_size_per_token (int) – Per-token payload size (in bytes) sent back during the combine phase.

Returns:

Required workspace size per rank, in bytes.

Return type:

int
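The sketch below shows one way to assemble the two byte-size arguments before calling the function. The payload layout is a hypothetical assumption, not taken from this page. It assumes each dispatched token carries bf16 hidden states (2 bytes per element), top_k int32 expert IDs, and top_k float32 routing weights, and that the combine phase sends back only the bf16 hidden states. The helper per_token_payload_sizes and the example values (hidden_size=7168, top_k=8, ep_size=8, max_num_tokens=256) are illustrative only and are not part of the FlashInfer API. The library call runs only if flashinfer.comm can be imported.

```python
# Hypothetical helper (not part of FlashInfer): computes per-token byte sizes
# for an ASSUMED payload layout. Adjust it to match the payloads you actually send.
def per_token_payload_sizes(hidden_size, top_k, act_bytes=2, id_bytes=4, scale_bytes=4):
    # Dispatch: hidden states + top_k expert IDs + top_k routing weights (assumed).
    dispatch = hidden_size * act_bytes + top_k * id_bytes + top_k * scale_bytes
    # Combine: only the hidden states are sent back (assumed).
    combine = hidden_size * act_bytes
    return dispatch, combine


dispatch_bytes, combine_bytes = per_token_payload_sizes(hidden_size=7168, top_k=8)

try:
    import flashinfer.comm as comm
except ImportError:
    comm = None  # FlashInfer is not installed; skip the library call.

if comm is not None:
    workspace_bytes = comm.moe_a2a_get_workspace_size_per_rank(
        ep_size=8,
        max_num_tokens=256,
        total_dispatch_payload_size_per_token=dispatch_bytes,
        combine_payload_size_per_token=combine_bytes,
    )
    print(f"workspace per rank: {workspace_bytes} bytes")
```

With these assumed sizes, dispatch_bytes is 7168 * 2 + 8 * 4 + 8 * 4 = 14400 and combine_bytes is 14336.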