flashinfer.comm.moe_a2a_dispatch
- flashinfer.comm.moe_a2a_dispatch(token_selected_experts: Tensor, input_payloads: list[Tensor], workspace: Tensor, metainfo: Tensor, runtime_max_tokens_per_rank: int, ep_rank: int, ep_size: int, top_k: int, num_experts: int)
Dispatch tokens and payloads to their target expert ranks.
- Parameters:
token_selected_experts (torch.Tensor) – [local_num_tokens, top_k] int32 tensor of expert assignments.
input_payloads (list[torch.Tensor]) – Per-token payload tensors, each shaped [local_num_tokens, *].
workspace (torch.Tensor) – [ep_size, size_per_rank] shared workspace.
metainfo (torch.Tensor) – Metainfo tensor returned by moe_a2a_initialize().
runtime_max_tokens_per_rank (int) – Maximum tokens per rank for this batch (must be <= the max_num_tokens used at initialize time).
ep_rank (int) – Current expert-parallel rank.
ep_size (int) – Total expert-parallel world size.
top_k (int) – Number of experts assigned per token.
num_experts (int) – Total number of experts.
- Returns:
(output_payloads, combine_payload_offset). output_payloads is a list of workspace-backed views, one per input_payloads entry, that contains the data routed to this rank. combine_payload_offset is the workspace offset reserved for the matching moe_a2a_combine() call.
- Return type:
Tuple[list[torch.Tensor], int]
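To make the routing idea concrete, the following is a minimal pure-Python sketch of what "dispatching tokens to their target expert ranks" could look like for a single source rank. It does not call flashinfer and is not the library's implementation. The helper name reference_dispatch is hypothetical, plain lists stand in for tensors, and two points are assumptions rather than documented behavior: that experts are split into contiguous, equal blocks across ranks (expert e lives on rank e // (num_experts // ep_size)), and that a token is delivered at most once to each distinct target rank.

```python
from typing import Dict, List


def reference_dispatch(
    token_selected_experts: List[List[int]],
    payload: List[object],
    ep_size: int,
    top_k: int,
    num_experts: int,
) -> Dict[int, List[object]]:
    """Hypothetical CPU model of routing one rank's tokens to expert ranks."""
    # ASSUMPTION: contiguous, equal expert blocks per rank.
    if num_experts % ep_size != 0:
        raise ValueError("this sketch assumes num_experts is divisible by ep_size")
    # Mirrors the documented shapes: one payload row per local token.
    if len(payload) != len(token_selected_experts):
        raise ValueError("payload must have one row per local token")

    experts_per_rank = num_experts // ep_size
    routed: Dict[int, List[object]] = {rank: [] for rank in range(ep_size)}

    for token_idx, experts in enumerate(token_selected_experts):
        # Mirrors the documented [local_num_tokens, top_k] shape.
        if len(experts) != top_k:
            raise ValueError(f"token {token_idx} must select exactly {top_k} experts")
        target_ranks = set()
        for expert in experts:
            if not 0 <= expert < num_experts:
                raise ValueError(f"expert id {expert} is out of range")
            target_ranks.add(expert // experts_per_rank)
        # ASSUMPTION: send each token once per distinct target rank.
        for rank in sorted(target_ranks):
            routed[rank].append(payload[token_idx])
    return routed


# Three local tokens, top_k=2, 8 experts spread over 4 ranks (2 experts per rank).
example = reference_dispatch(
    token_selected_experts=[[0, 5], [2, 3], [6, 7]],
    payload=["t0", "t1", "t2"],
    ep_size=4,
    top_k=2,
    num_experts=8,
)
print(example)
```

Under these assumptions, token t0 goes to ranks 0 and 2, while t1 and t2 each go to a single rank because both of their experts fall in the same block. The real call additionally needs the shared workspace and the metainfo tensor from moe_a2a_initialize(), and returns workspace-backed views rather than Python lists.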