flashinfer.comm.decode_cp_a2a_alltoall
- flashinfer.comm.decode_cp_a2a_alltoall(partial_o: Tensor, softmax_stats: Tensor, workspace: Tensor, cp_rank: int, cp_size: int, enable_pdl: bool | None = None) → tuple[Tensor, Tensor]
Perform the DCP all-to-all exchange.
Each rank sends its partial_o[..., peer, :] slice to the corresponding peer and receives all peers’ contributions into the output tensors. The softmax_stats[..., peer, :] slices are exchanged the same way.
- Parameters:
partial_o (torch.Tensor) – [..., cp_size, D] half or bfloat16 tensor. D * element_size must be 16-byte aligned (that is, a multiple of 16 bytes).
softmax_stats (torch.Tensor) – [..., cp_size, S] float32 tensor with S >= 2 and even. Its leading (batch) dimensions must match those of partial_o.
workspace (torch.Tensor) – [cp_size, ws_elems_per_rank] int64 tensor from decode_cp_a2a_allocate_mnnvl_workspace(), already initialized.
cp_rank (int) – This rank’s position in the CP group.
cp_size (int) – Context-parallel group size.
enable_pdl (bool, optional) – Enable Programmatic Dependent Launch (SM90+). Defaults to True on SM90+ GPUs and False otherwise.
- Returns:
(partial_o_out, softmax_stats_out) with the same shapes and dtypes as the inputs. Each output contains the data gathered from all peers for this rank.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
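To make the slot indexing concrete, here is a minimal single-process model of the exchange in NumPy. It assumes, without confirmation from this page, that slot peer of rank r's output holds peer's input slice [..., r, :]. The helper name reference_cp_alltoall is hypothetical, and the sketch ignores the workspace, PDL, and all cross-GPU transport; it only models where each slice ends up. The same function applies unchanged to the per-rank softmax_stats tensors.

```python
import numpy as np


def reference_cp_alltoall(per_rank):
    """Hypothetical single-process model of the DCP all-to-all on one tensor.

    per_rank[r] is rank r's input with shape [..., cp_size, X].
    ASSUMED layout: slot `peer` of rank r's output holds peer's slice [..., r, :].
    """
    cp_size = len(per_rank)
    for t in per_rank:
        # Every rank must contribute one slice per peer along dim -2.
        if t.shape[-2] != cp_size or t.shape != per_rank[0].shape:
            raise ValueError("each input must have shape [..., cp_size, X]")
    outputs = []
    for rank in range(cp_size):
        # Collect the slice each peer addressed to `rank`, in peer order.
        received = [per_rank[peer][..., rank, :] for peer in range(cp_size)]
        # Stack along dim -2 so the output keeps the [..., cp_size, X] shape.
        outputs.append(np.stack(received, axis=-2))
    return outputs


if __name__ == "__main__":
    cp_size, batch, d = 4, 2, 8
    # Rank r fills its whole partial_o with the value r.
    partial_o = [np.full((batch, cp_size, d), r, dtype=np.float16) for r in range(cp_size)]
    out = reference_cp_alltoall(partial_o)
    # Slot p of rank 0's output came from rank p.
    print(out[0][0, :, 0].tolist())  # → [0.0, 1.0, 2.0, 3.0]
```

Because the output slot order mirrors the input slot order under this assumption, applying the model twice returns the original per-rank tensors, which makes it a convenient oracle when checking the results of a real multi-GPU run.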
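The documented input constraints and the enable_pdl default can be checked before launching. The sketch below uses hypothetical helpers (check_decode_cp_a2a_inputs and resolve_enable_pdl are not FlashInfer functions) that operate on plain shapes and dtype names. For the 2-byte half and bfloat16 types, the 16-byte rule means D must be a multiple of 8: D = 128 gives 256 bytes and passes, while D = 100 gives 200 bytes and fails.

```python
from typing import Optional, Sequence, Tuple

# Bytes per element for the partial_o dtypes documented as supported.
_PARTIAL_O_ITEMSIZE = {"float16": 2, "bfloat16": 2}


def check_decode_cp_a2a_inputs(
    partial_o_shape: Sequence[int],
    partial_o_dtype: str,
    softmax_stats_shape: Sequence[int],
    softmax_stats_dtype: str,
    cp_size: int,
) -> None:
    """Hypothetical pre-launch check; raises ValueError on a documented violation."""
    if partial_o_dtype not in _PARTIAL_O_ITEMSIZE:
        raise ValueError(f"partial_o must be float16 or bfloat16, got {partial_o_dtype}")
    if softmax_stats_dtype != "float32":
        raise ValueError(f"softmax_stats must be float32, got {softmax_stats_dtype}")
    # Both tensors carry one slice per peer in dim -2.
    for name, shape in (("partial_o", partial_o_shape), ("softmax_stats", softmax_stats_shape)):
        if len(shape) < 2 or shape[-2] != cp_size:
            raise ValueError(f"{name} must be [..., {cp_size}, X], got {tuple(shape)}")
    # D * element_size must be a multiple of 16 bytes.
    row_bytes = partial_o_shape[-1] * _PARTIAL_O_ITEMSIZE[partial_o_dtype]
    if row_bytes % 16 != 0:
        raise ValueError(f"D * element_size = {row_bytes} bytes is not 16-byte aligned")
    # S must be even and at least 2.
    s = softmax_stats_shape[-1]
    if s < 2 or s % 2 != 0:
        raise ValueError(f"S must be even and >= 2, got {s}")
    # Leading (batch) dimensions must agree.
    if tuple(partial_o_shape[:-2]) != tuple(softmax_stats_shape[:-2]):
        raise ValueError("batch dimensions of partial_o and softmax_stats differ")


def resolve_enable_pdl(enable_pdl: Optional[bool], compute_capability: Tuple[int, int]) -> bool:
    """Documented default: PDL on for SM90+ (capability >= 9.0), off otherwise."""
    if enable_pdl is not None:
        return enable_pdl
    return tuple(compute_capability) >= (9, 0)
```

In a real program the shapes would come from tensor.shape and the compute capability from torch.cuda.get_device_capability(), which returns a (major, minor) tuple. Whatever checks the library itself performs may differ from this sketch.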