flashinfer.comm.decode_cp_a2a_alltoall

flashinfer.comm.decode_cp_a2a_alltoall(partial_o: Tensor, softmax_stats: Tensor, workspace: Tensor, cp_rank: int, cp_size: int, enable_pdl: bool | None = None) → tuple[Tensor, Tensor]

Perform the decode context-parallel (DCP) all-to-all exchange.

Each rank sends its partial_o[..., peer, :] slice (together with the matching softmax_stats[..., peer, :] slice) to rank peer, and receives every peer's contribution into the output tensors.
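To make the data movement concrete, the following is a single-process PyTorch sketch that emulates the exchange for every rank at once. It is not the FlashInfer kernel. It assumes that slot peer along the cp_size dimension of a rank's output holds the slice received from rank peer, which is a natural reading of the description above but is not stated explicitly; reference_alltoall is a hypothetical helper name.

```python
import torch


def reference_alltoall(partial_os, stats):
    """Single-process emulation of the DCP all-to-all (hypothetical helper).

    partial_os[r] / stats[r] are rank r's partial_o / softmax_stats inputs,
    shaped [..., cp_size, D] and [..., cp_size, S].
    Assumed layout: out[r][..., peer, :] == inputs[peer][..., r, :].
    """
    cp_size = len(partial_os)
    o_out, s_out = [], []
    for r in range(cp_size):
        # Rank r collects slice r from every peer, stacked in peer order
        # along the cp_size dimension, so the output shape matches the input.
        o_out.append(torch.stack([partial_os[p][..., r, :] for p in range(cp_size)], dim=-2))
        s_out.append(torch.stack([stats[p][..., r, :] for p in range(cp_size)], dim=-2))
    return o_out, s_out


# float32 is used here for simplicity; the emulation does not depend on dtype.
cp_size, batch, D, S = 4, 3, 8, 2
partial_os = [torch.randn(batch, cp_size, D) for _ in range(cp_size)]
stats = [torch.randn(batch, cp_size, S) for _ in range(cp_size)]
o_out, s_out = reference_alltoall(partial_os, stats)
print(o_out[0].shape, s_out[0].shape)
```

Under this assumed layout the exchange is a transpose over the (rank, peer) axes, so applying the sketch twice restores the original tensors. That property could make it a useful oracle when checking real multi-GPU runs.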

Parameters:
  • partial_o (torch.Tensor) – [..., cp_size, D] half (float16) or bfloat16 tensor. D * element_size must be a multiple of 16 bytes (16-byte aligned).

  • softmax_stats (torch.Tensor) – [..., cp_size, S] float32 tensor, where S is even and at least 2. Its batch dimensions must match those of partial_o.

  • workspace (torch.Tensor) – [cp_size, ws_elems_per_rank] int64 tensor from decode_cp_a2a_allocate_mnnvl_workspace(), already initialized.

  • cp_rank (int) – This rank’s position in the CP group.

  • cp_size (int) – Context-parallel group size.

  • enable_pdl (bool, optional) – Enable Programmatic Dependent Launch (SM90+). Defaults to True on SM90+ GPUs and False otherwise.
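The parameter constraints above can be checked on the host before launching the exchange. This is a minimal sketch: check_a2a_inputs is a hypothetical helper, not part of flashinfer, and it interprets "16-byte aligned" as D * element_size being a multiple of 16 bytes. Whether the workspace has already been initialized cannot be determined from the tensor alone, so that requirement is not checked.

```python
import torch


def check_a2a_inputs(partial_o, softmax_stats, workspace, cp_rank, cp_size):
    """Return a list of violated constraints (empty if the inputs look valid)."""
    errors = []
    if partial_o.dtype not in (torch.float16, torch.bfloat16):
        errors.append("partial_o must be float16 (half) or bfloat16")
    if partial_o.dim() < 2 or partial_o.shape[-2] != cp_size:
        errors.append("partial_o must have shape [..., cp_size, D]")
    elif (partial_o.shape[-1] * partial_o.element_size()) % 16 != 0:
        # Interpretation of "16-byte aligned": bytes per row divisible by 16.
        errors.append("D * element_size must be a multiple of 16 bytes")
    if softmax_stats.dtype != torch.float32:
        errors.append("softmax_stats must be float32")
    if softmax_stats.dim() < 2 or softmax_stats.shape[-2] != cp_size:
        errors.append("softmax_stats must have shape [..., cp_size, S]")
    else:
        s = softmax_stats.shape[-1]
        if s < 2 or s % 2 != 0:
            errors.append("S must be even and >= 2")
        if softmax_stats.shape[:-2] != partial_o.shape[:-2]:
            errors.append("batch dimensions of softmax_stats and partial_o differ")
    if workspace.dtype != torch.int64 or workspace.dim() != 2 or workspace.shape[0] != cp_size:
        errors.append("workspace must be an int64 tensor of shape [cp_size, ws_elems_per_rank]")
    if not 0 <= cp_rank < cp_size:
        errors.append("cp_rank must be in [0, cp_size)")
    return errors


po = torch.zeros(3, 2, 8, dtype=torch.float16)  # 8 * 2 bytes = 16 bytes per row
st = torch.zeros(3, 2, 2, dtype=torch.float32)
ws = torch.zeros(2, 16, dtype=torch.int64)  # placeholder; a real one comes from the allocator
print(check_a2a_inputs(po, st, ws, cp_rank=0, cp_size=2))  # → []
```

Returning a list instead of raising on the first failure reports every problem in one pass, which may be more convenient when debugging a misconfigured CP group.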

Returns:

(partial_o_out, softmax_stats_out) with the same shapes and dtypes as the inputs. Each output holds the slices that all peers sent to this rank.

Return type:

Tuple[torch.Tensor, torch.Tensor]