flashinfer.comm.decode_cp_a2a_init_workspace

flashinfer.comm.decode_cp_a2a_init_workspace(workspace: Tensor, cp_rank: int, cp_size: int) → None

Initialize the workspace FIFO buffers (call once before the first alltoall).

Resets the FIFO buffers in the local workspace row (workspace[cp_rank]). This function is synchronous: when it returns, the GPU memset is guaranteed to have completed.

Important

With MNNVL workspaces, all ranks must complete decode_cp_a2a_init_workspace and then execute a cross-rank barrier (e.g. dist.barrier(group)) before any rank calls decode_cp_a2a_alltoall(). Without the barrier, a rank may start writing to a peer’s FIFO before that peer has finished initializing it, which can lead to a deadlock.
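
The ordering requirement above can be sketched as a small wrapper. This is a minimal illustration, not part of FlashInfer: init_then_barrier is a hypothetical helper, and the init and barrier steps are passed in as callables so the sketch runs without a GPU or a process group.

```python
from typing import Any, Callable


def init_then_barrier(
    workspace: Any,
    cp_rank: int,
    cp_size: int,
    init_fn: Callable[[Any, int, int], None],
    barrier_fn: Callable[[], None],
) -> None:
    """Hypothetical helper: reset the local FIFO row, then wait for all peers."""
    # Step 1: local, synchronous reset of workspace[cp_rank].
    init_fn(workspace, cp_rank, cp_size)
    # Step 2: cross-rank barrier. No rank should start the alltoall
    # until this call has returned on every rank.
    barrier_fn()
```

In a real job, init_fn would presumably be flashinfer.comm.decode_cp_a2a_init_workspace and barrier_fn something like lambda: dist.barrier(group), where dist is torch.distributed and group is assumed to be the process group that spans the CP ranks.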

Parameters:
  • workspace (torch.Tensor) – [cp_size, ws_elems_per_rank] int64 tensor from decode_cp_a2a_allocate_mnnvl_workspace().

  • cp_rank (int) – This rank’s position in the CP group.

  • cp_size (int) – Context-parallel group size.
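
The parameter constraints above can be checked before calling the function. The sketch below is a hypothetical, duck-typed helper (check_init_args is not a FlashInfer function). It assumes only that workspace exposes .shape and .dtype, as a torch.Tensor does, and it infers the rule 0 <= cp_rank < cp_size from the fact that workspace[cp_rank] indexes a first dimension of size cp_size.

```python
from typing import Any


def check_init_args(workspace: Any, cp_rank: int, cp_size: int) -> None:
    """Hypothetical pre-flight check for the documented argument layout."""
    shape = tuple(workspace.shape)
    # Documented layout: [cp_size, ws_elems_per_rank].
    if len(shape) != 2 or shape[0] != cp_size:
        raise ValueError(f"expected shape [{cp_size}, N], got {list(shape)}")
    # Documented dtype: int64. str(torch.int64) is "torch.int64",
    # so compare only the part after the last dot.
    if str(workspace.dtype).split(".")[-1] != "int64":
        raise TypeError(f"expected int64 workspace, got {workspace.dtype}")
    # Inferred: cp_rank selects one row of the workspace.
    if not 0 <= cp_rank < cp_size:
        raise ValueError(f"cp_rank {cp_rank} is outside [0, {cp_size})")
```

Running such a check on every rank before initialization turns a mismatched group size or rank index into an immediate Python error instead of a problem that only appears later, during the alltoall.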