flashinfer.comm.moe_a2a_dispatch

flashinfer.comm.moe_a2a_dispatch(token_selected_experts: Tensor, input_payloads: list[Tensor], workspace: Tensor, metainfo: Tensor, runtime_max_tokens_per_rank: int, ep_rank: int, ep_size: int, top_k: int, num_experts: int)

Dispatch tokens and payloads to their target expert ranks.

Parameters:
  • token_selected_experts (torch.Tensor) – [local_num_tokens, top_k] int32 tensor of expert assignments.

  • input_payloads (list[torch.Tensor]) – Per-token payload tensors, each shaped [local_num_tokens, *].

  • workspace (torch.Tensor) – [ep_size, size_per_rank] shared workspace.

  • metainfo (torch.Tensor) – Metainfo tensor returned by moe_a2a_initialize().

  • runtime_max_tokens_per_rank (int) – Maximum tokens per rank for this batch (must be <= the max_num_tokens used at initialize time).

  • ep_rank (int) – Current expert-parallel rank.

  • ep_size (int) – Total expert-parallel world size.

  • top_k (int) – Number of experts assigned per token.

  • num_experts (int) – Total number of experts.
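To make the relationship between token_selected_experts, num_experts, and ep_size concrete, the sketch below computes which ranks each local token would be sent to. It assumes experts are partitioned contiguously and evenly across ranks (num_experts // ep_size experts per rank); that layout is an assumption for illustration and is not stated by this reference. The helper name target_ranks_per_token is hypothetical, and plain Python lists stand in for the int32 tensor.

```python
def target_ranks_per_token(token_selected_experts, ep_size, num_experts):
    """Return, for each token, the sorted list of distinct ranks hosting its experts.

    Assumption (not taken from the FlashInfer docs): rank r hosts experts
    [r * experts_per_rank, (r + 1) * experts_per_rank).
    """
    if num_experts % ep_size != 0:
        raise ValueError("this sketch assumes num_experts is divisible by ep_size")
    experts_per_rank = num_experts // ep_size

    routes = []
    for row in token_selected_experts:  # one row of top_k expert ids per token
        ranks = set()
        for expert_id in row:
            if not 0 <= expert_id < num_experts:
                raise ValueError(f"expert id {expert_id} is out of range")
            ranks.add(expert_id // experts_per_rank)
        routes.append(sorted(ranks))
    return routes


# Three local tokens, top_k = 2, 8 experts spread over 4 ranks (2 experts each).
selected = [[0, 1], [2, 7], [5, 4]]
routes = target_ranks_per_token(selected, ep_size=4, num_experts=8)
print(routes)  # [[0], [1, 3], [2]]
```

Under this assumed layout, a token whose top_k experts all live on one rank has a single destination, while a token whose experts span several ranks has one destination per distinct rank.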

Returns:

A tuple (output_payloads, combine_payload_offset). output_payloads is a list of workspace-backed views, one per entry in input_payloads, each containing the data routed to this rank. combine_payload_offset is the workspace offset reserved for the matching moe_a2a_combine() call.

Return type:

Tuple[list[torch.Tensor], int]
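As a rough mental model of "the data routed to this rank", the toy simulation below runs the dispatch for every rank in a single process and records what each destination would receive. It is only a sketch: simulate_dispatch is a hypothetical helper, the contiguous expert-to-rank partition is an assumption, and so is sending a token at most once to each destination rank. The real function writes into the shared workspace and returns tensor views whose layout is not described here.

```python
def simulate_dispatch(selected_by_rank, payload_by_rank, ep_size, num_experts):
    """Toy all-to-all: return, per destination rank, (src_rank, token_idx, payload_row).

    Assumptions (illustrative only): experts are split contiguously across
    ranks, and a token is delivered once per distinct destination rank.
    """
    experts_per_rank = num_experts // ep_size
    received = [[] for _ in range(ep_size)]
    for src, (selected, payload) in enumerate(zip(selected_by_rank, payload_by_rank)):
        for tok, experts in enumerate(selected):
            destinations = sorted({e // experts_per_rank for e in experts})
            for dst in destinations:
                received[dst].append((src, tok, payload[tok]))
    return received


# Two ranks, four experts (experts 0-1 on rank 0, experts 2-3 on rank 1), top_k = 2.
selected_by_rank = [
    [[0, 3]],          # rank 0 holds one token, routed to experts 0 and 3
    [[2, 3], [1, 0]],  # rank 1 holds two tokens
]
payload_by_rank = [
    ["a0"],            # stand-ins for per-token payload rows
    ["b0", "b1"],
]
received = simulate_dispatch(selected_by_rank, payload_by_rank, ep_size=2, num_experts=4)
for rank, rows in enumerate(received):
    print(rank, rows)
```

In this model each rank ends up holding rows that originated on any rank, which is why the real outputs are views into a workspace shared across ranks instead of copies of the local inputs.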