flashinfer.comm.moe_a2a_combine
- flashinfer.comm.moe_a2a_combine(payload: Tensor, local_num_tokens: int, workspace: Tensor, metainfo: Tensor, runtime_max_tokens_per_rank: int, ep_rank: int, ep_size: int, top_k: int, combine_payload_offset: int, payload_in_workspace: bool = False, output_dtype: dtype | None = None, output_scales: Tensor | None = None, sf_layout: SfLayout = SfLayout.layout_linear) → Tensor
Combine per-expert outputs back to the originating ranks.
Inverse of moe_a2a_dispatch(): scatters the rank-local expert output rows back to the ranks that supplied the original tokens.
- Parameters:
payload (torch.Tensor) – Output payload to send back to the source ranks. Its shape is [ep_size, runtime_max_tokens_per_rank, *] regardless of payload_in_workspace; in both cases the payload holds the per-expert-rank outputs to be combined back to the source ranks. Only the backing memory differs: a caller-supplied tensor vs. a workspace-backed view produced by MoeAlltoAll.get_combine_payload_tensor_in_workspace().
local_num_tokens (int) – Number of tokens originally dispatched from this rank.
workspace (torch.Tensor) – Shared workspace tensor (same one passed to dispatch).
metainfo (torch.Tensor) – Metainfo tensor returned by moe_a2a_initialize().
runtime_max_tokens_per_rank (int) – Same value passed to moe_a2a_dispatch().
ep_rank (int) – Current expert-parallel rank.
ep_size (int) – Total expert-parallel world size.
top_k (int) – Number of experts assigned per token.
combine_payload_offset (int) – Offset returned by moe_a2a_dispatch().
payload_in_workspace (bool) – True if payload is already a workspace-backed view (skips the staging copy). Defaults to False.
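The shape and rank contracts above can be checked on the caller side before the call. The following is a minimal, dependency-free sketch: check_combine_args is a hypothetical helper (not part of FlashInfer), and the rule local_num_tokens <= runtime_max_tokens_per_rank is an assumption inferred from the parameter names rather than a documented requirement.

```python
from typing import Sequence, Tuple


def check_combine_args(
    payload_shape: Sequence[int],
    local_num_tokens: int,
    runtime_max_tokens_per_rank: int,
    ep_rank: int,
    ep_size: int,
    top_k: int,
) -> Tuple[int, ...]:
    """Hypothetical pre-flight check (not a FlashInfer API).

    Validates arguments against the documented contract of
    moe_a2a_combine and returns the expected output shape.
    """
    if ep_size <= 0 or not 0 <= ep_rank < ep_size:
        raise ValueError(f"ep_rank={ep_rank} must lie in [0, ep_size={ep_size})")
    if top_k <= 0:
        raise ValueError(f"top_k={top_k} must be positive")
    # Documented: payload is [ep_size, runtime_max_tokens_per_rank, *]
    # whether or not payload_in_workspace is set.
    if len(payload_shape) < 2 or tuple(payload_shape[:2]) != (ep_size, runtime_max_tokens_per_rank):
        raise ValueError(
            f"payload shape {tuple(payload_shape)} does not start with "
            f"({ep_size}, {runtime_max_tokens_per_rank})"
        )
    # ASSUMPTION: this rank never dispatched more than the runtime maximum.
    if not 0 <= local_num_tokens <= runtime_max_tokens_per_rank:
        raise ValueError(
            f"local_num_tokens={local_num_tokens} must lie in "
            f"[0, runtime_max_tokens_per_rank={runtime_max_tokens_per_rank}]"
        )
    # Documented: the result is [local_num_tokens, *]; trailing dims carry over.
    return (local_num_tokens, *payload_shape[2:])


# Example: 4 EP ranks, hidden size 7168, 100 tokens dispatched from rank 1.
print(check_combine_args((4, 128, 7168), 100, 128, ep_rank=1, ep_size=4, top_k=8))
# → (100, 7168)
```

Checking shapes in plain Python keeps the failure message readable instead of surfacing as a kernel error on one rank.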
- Returns:
Tensor of shape [local_num_tokens, *] with the combined outputs.
- Return type:
torch.Tensor
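To make the data movement concrete, here is a toy, single-process model of what combine might return on one source rank. It is not FlashInfer's kernel. toy_combine and zeros are hypothetical helpers, and the model assumes that (a) the row for token t from source rank s sits at payload[s, t] on each receiving rank, (b) that row already aggregates the receiving rank's local experts for the token, and (c) rows from distinct receiving ranks are summed. The real slot layout and reduction may differ.

```python
from typing import Dict, List, Sequence

Row = List[float]


def toy_combine(
    payloads: Dict[int, List[List[Row]]],
    src_rank: int,
    local_num_tokens: int,
    token_target_ranks: Sequence[Sequence[int]],
    hidden: int,
) -> List[Row]:
    """Toy single-process model of combine for one source rank (hypothetical).

    payloads[r] stands in for the payload on rank r, laid out as
    [ep_size][runtime_max_tokens_per_rank][hidden].
    token_target_ranks[t] holds the top_k destination ranks of token t.
    """
    out = [[0.0] * hidden for _ in range(local_num_tokens)]
    for t in range(local_num_tokens):
        # ASSUMPTION: each distinct destination rank returns one pre-aggregated
        # row at [src_rank][t], and those rows are summed.
        for r in sorted(set(token_target_ranks[t])):
            row = payloads[r][src_rank][t]
            for h in range(hidden):
                out[t][h] += row[h]
    return out


def zeros(ep_size: int, max_tokens: int, hidden: int) -> List[List[Row]]:
    # Zero-filled stand-in for one rank's [ep_size, max_tokens, hidden] payload.
    return [[[0.0] * hidden for _ in range(max_tokens)] for _ in range(ep_size)]


# Two ranks, top_k = 2; rank 0 dispatched two tokens.
payloads = {0: zeros(2, 2, 2), 1: zeros(2, 2, 2)}
payloads[0][0][0] = [1.0, 2.0]    # rank 0's experts' output for token 0
payloads[1][0][0] = [10.0, 20.0]  # rank 1's experts' output for token 0
payloads[1][0][1] = [3.0, 4.0]    # both experts of token 1 live on rank 1
print(toy_combine(payloads, src_rank=0, local_num_tokens=2,
                  token_target_ranks=[[0, 1], [1, 1]], hidden=2))
# → [[11.0, 22.0], [3.0, 4.0]]
```

Token 0 is routed to ranks 0 and 1, so its output is the sum of the two returned rows; both of token 1's experts live on rank 1, so under assumption (b) it receives a single row. Either way the result has local_num_tokens rows, matching the documented return shape.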