flashinfer.comm.moe_a2a_get_workspace_size_per_rank

flashinfer.comm.moe_a2a_get_workspace_size_per_rank(ep_size: int, max_num_tokens: int, total_dispatch_payload_size_per_token: int, combine_payload_size_per_token: int)

Return the workspace size, in bytes, that each rank must allocate for the MoeAlltoAll operation.

Parameters:
  • ep_size – Total expert-parallel (EP) size, i.e. the number of ranks in the expert-parallel group

  • max_num_tokens – Maximum number of tokens across all ranks

  • total_dispatch_payload_size_per_token – Payload size per token in the dispatch phase, in bytes. If several payloads are dispatched together, pass the sum of their per-token sizes.

  • combine_payload_size_per_token – Payload size per token in the combine phase, in bytes.
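Because total_dispatch_payload_size_per_token must be the sum over all dispatched payloads, it can help to compute it from each payload's per-token shape and element size. The sketch below is illustrative only: the helper name `payload_bytes_per_token` is hypothetical (not part of flashinfer), and the payload layout (bf16 hidden states, int32 expert ids and float32 routing weights on dispatch; bf16 hidden states on combine) is an assumed example configuration, not one required by the library.

```python
from typing import Sequence, Tuple


def payload_bytes_per_token(payloads: Sequence[Tuple[int, int]]) -> int:
    """Hypothetical helper: sum the per-token byte sizes of several payloads.

    Each entry is (elements_per_token, bytes_per_element).
    """
    return sum(elems * itemsize for elems, itemsize in payloads)


# Assumed example configuration (not taken from flashinfer).
hidden_size = 4096
top_k = 8

# Dispatch (assumed layout): bf16 hidden states (2 bytes/element),
# int32 expert ids (4 bytes/element), float32 routing weights (4 bytes/element).
dispatch_bytes = payload_bytes_per_token([
    (hidden_size, 2),
    (top_k, 4),
    (top_k, 4),
])

# Combine (assumed layout): one bf16 hidden vector per token.
combine_bytes = payload_bytes_per_token([(hidden_size, 2)])

print(dispatch_bytes, combine_bytes)
```

The two results would then be passed as `total_dispatch_payload_size_per_token=dispatch_bytes` and `combine_payload_size_per_token=combine_bytes` to `flashinfer.comm.moe_a2a_get_workspace_size_per_rank`, together with `ep_size` and `max_num_tokens` for your deployment.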

Returns:

workspace_size_per_rank – Size of the workspace per rank, in bytes.

Return type:

int