flashinfer.comm.moe_a2a_combine

flashinfer.comm.moe_a2a_combine(payload: Tensor, local_num_tokens: int, workspace: Tensor, metainfo: Tensor, runtime_max_tokens_per_rank: int, ep_rank: int, ep_size: int, top_k: int, combine_payload_offset: int, payload_in_workspace: bool = False, output_dtype: dtype | None = None, output_scales: Tensor | None = None, sf_layout: SfLayout = SfLayout.layout_linear) → Tensor

Combine per-expert outputs back to the originating ranks.

Inverse of moe_a2a_dispatch(): scatters the rank-local expert output rows back to the ranks that supplied the original tokens.
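To make the dispatch/combine relationship concrete, here is a toy single-process model in plain Python. It is not how the FlashInfer kernels work, and toy_dispatch / toy_combine are hypothetical names. It assumes that (a) dispatch places token t of source rank src at slot [src][t] of each distinct target rank's buffer, so each rank's buffer mirrors the documented [ep_size, runtime_max_tokens_per_rank, *] payload shape; (b) the experts on a rank have already reduced their contributions into that single row; and (c) combine sums the rows returned by every target rank. The targets list stands in for the routing bookkeeping the real API keeps internally (also an assumption).

```python
from typing import List, Optional

Row = List[float]


def toy_dispatch(tokens: List[List[Row]], targets: List[List[List[int]]],
                 max_tokens: int) -> List[List[List[Optional[Row]]]]:
    """Toy dispatch: recv[dst][src][t] holds token t of rank src on rank dst (or None)."""
    ep_size = len(tokens)
    recv = [[[None] * max_tokens for _ in range(ep_size)] for _ in range(ep_size)]
    for src in range(ep_size):
        for t, row in enumerate(tokens[src]):
            # Assumption: one copy per distinct rank hosting one of the token's top_k experts.
            for dst in set(targets[src][t]):
                recv[dst][src][t] = row
    return recv


def toy_combine(payloads: List[List[List[Optional[Row]]]], src_rank: int,
                local_num_tokens: int, targets: List[List[List[int]]]) -> List[Row]:
    """Toy combine: sum the rows each target rank returns for this rank's tokens."""
    out = []
    for t in range(local_num_tokens):
        acc = None
        for dst in sorted(set(targets[src_rank][t])):
            row = payloads[dst][src_rank][t]
            acc = list(row) if acc is None else [a + b for a, b in zip(acc, row)]
        out.append(acc)
    return out


tokens = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0]]]  # rank 0 owns 2 tokens, rank 1 owns 1
targets = [[[0, 1], [1, 1]], [[0, 1]]]             # top_k = 2 target ranks per token
recv = toy_dispatch(tokens, targets, max_tokens=2)

# Stand-in for the local experts: rank dst scales every received row by (dst + 1).
payloads = [
    [[None if row is None else [x * (dst + 1) for x in row] for row in per_src]
     for per_src in recv[dst]]
    for dst in range(len(recv))
]

rank0_out = toy_combine(payloads, src_rank=0, local_num_tokens=2, targets=targets)
rank1_out = toy_combine(payloads, src_rank=1, local_num_tokens=1, targets=targets)
print(rank0_out)  # [[3.0, 6.0], [6.0, 8.0]]
print(rank1_out)  # [[15.0, 18.0]]
```

Each source rank gets back exactly local_num_tokens rows, matching the [local_num_tokens, *] result described under Returns.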

Parameters:
  • payload (torch.Tensor) – Expert outputs to send back to the source ranks, with shape [ep_size, runtime_max_tokens_per_rank, *]. The shape and contents are the same whether or not payload_in_workspace is set; only the backing memory differs (a caller-supplied tensor vs. the workspace-backed view returned by MoeAlltoAll.get_combine_payload_tensor_in_workspace()).

  • local_num_tokens (int) – Number of tokens originally dispatched from this rank.

  • workspace (torch.Tensor) – Shared workspace tensor (the same one passed to moe_a2a_dispatch()).

  • metainfo (torch.Tensor) – Metainfo tensor returned by moe_a2a_initialize().

  • runtime_max_tokens_per_rank (int) – Same value passed to moe_a2a_dispatch().

  • ep_rank (int) – Current expert-parallel rank.

  • ep_size (int) – Total expert-parallel world size.

  • top_k (int) – Number of experts assigned per token.

  • combine_payload_offset (int) – Offset returned by moe_a2a_dispatch().

  • payload_in_workspace (bool) – Set to True if payload is already a workspace-backed view; this skips the staging copy into the workspace. Defaults to False.
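Several parameters above must agree with each other and with the earlier dispatch call. The helper below is a hypothetical pre-flight check, not part of flashinfer, that encodes the constraints stated in the parameter descriptions on plain shape tuples, so it runs without torch (a torch.Size is a tuple, so payload.shape could be passed directly). The bound local_num_tokens <= runtime_max_tokens_per_rank is an assumption inferred from the payload shape, not something this page states.

```python
from typing import Sequence


def check_combine_args(payload_shape: Sequence[int], local_num_tokens: int,
                       runtime_max_tokens_per_rank: int, ep_rank: int,
                       ep_size: int, top_k: int) -> None:
    """Hypothetical sanity check for moe_a2a_combine() arguments; raises ValueError."""
    if not 0 <= ep_rank < ep_size:
        raise ValueError(f"ep_rank={ep_rank} must lie in [0, {ep_size})")
    if top_k < 1:
        raise ValueError(f"top_k={top_k} must be at least 1")
    # Documented payload shape: [ep_size, runtime_max_tokens_per_rank, *].
    if len(payload_shape) < 2 or tuple(payload_shape[:2]) != (ep_size, runtime_max_tokens_per_rank):
        raise ValueError(
            f"payload shape {tuple(payload_shape)} must start with "
            f"({ep_size}, {runtime_max_tokens_per_rank})")
    # Assumption: a rank never dispatches more than runtime_max_tokens_per_rank tokens.
    if not 0 <= local_num_tokens <= runtime_max_tokens_per_rank:
        raise ValueError(
            f"local_num_tokens={local_num_tokens} exceeds "
            f"runtime_max_tokens_per_rank={runtime_max_tokens_per_rank}")


# Consistent arguments pass silently.
check_combine_args((4, 128, 7168), local_num_tokens=100, runtime_max_tokens_per_rank=128,
                   ep_rank=1, ep_size=4, top_k=8)

# A payload sized with a different runtime_max_tokens_per_rank is rejected.
try:
    check_combine_args((4, 64, 7168), 100, 128, 1, 4, 8)
except ValueError as err:
    print(err)
```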

Returns:

Tensor of shape [local_num_tokens, *] holding the combined outputs for the tokens this rank originally dispatched.

Return type:

torch.Tensor
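The same * appears in both the payload shape and the return shape, which suggests that the trailing dimensions carry over unchanged while the two leading dimensions collapse to local_num_tokens. The hypothetical helper below encodes that reading of the notation; it is an interpretation, not a documented guarantee, and it does not model any effect of output_dtype.

```python
from typing import Sequence, Tuple


def combine_output_shape(payload_shape: Sequence[int], local_num_tokens: int) -> Tuple[int, ...]:
    """Expected result shape under the reading [ep_size, max_tokens, *] -> [local_num_tokens, *]."""
    if len(payload_shape) < 2:
        raise ValueError("payload must have at least two leading dimensions")
    # Drop [ep_size, runtime_max_tokens_per_rank]; keep the trailing '*' dimensions.
    return (local_num_tokens, *payload_shape[2:])


print(combine_output_shape((8, 256, 4096), 200))  # (200, 4096)
```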