flashinfer.comm.decode_cp_a2a_allocate_mnnvl_workspace

flashinfer.comm.decode_cp_a2a_allocate_mnnvl_workspace(mapping: Mapping, *, mnnvl_config: MnnvlConfig | None = None) → Tensor

Allocate an MNNVL-backed workspace of shape [cp_size, ws_elems_per_rank].

The DCP A2A kernel requires a single unified virtual address (VA) space spanning all CP ranks (see the module docstring), so the workspace must be allocated from MNNVL fabric memory. This function is the only supported allocator.
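Because the returned tensor is one [cp_size, ws_elems_per_rank] allocation, row r can be read as the region that belongs to CP rank r. The minimal sketch below assumes the tensor is contiguous in row-major order (the PyTorch default) and uses hypothetical helper names; it computes the byte range of each rank's row relative to the workspace base. Our reading of the unified-VA requirement is that this fixed offset from a single base address is what lets the kernel reach a peer's row.

```python
INT64_BYTES = 8  # element size of torch.int64


def rank_row_offset_bytes(cp_rank: int, ws_elems_per_rank: int) -> int:
    """Hypothetical helper: byte offset of row `cp_rank` from the base of a
    contiguous, row-major [cp_size, ws_elems_per_rank] int64 buffer."""
    return cp_rank * ws_elems_per_rank * INT64_BYTES


def rank_row_range(cp_rank: int, cp_size: int, ws_elems_per_rank: int) -> tuple[int, int]:
    """Hypothetical helper: half-open byte range [start, end) of one rank's row."""
    if not 0 <= cp_rank < cp_size:
        raise ValueError(f"cp_rank {cp_rank} out of range for cp_size {cp_size}")
    start = rank_row_offset_bytes(cp_rank, ws_elems_per_rank)
    return start, start + ws_elems_per_rank * INT64_BYTES


if __name__ == "__main__":
    # Illustrative sizes only: 4 CP ranks, 1024 int64 elements per rank.
    cp_size, ws_elems_per_rank = 4, 1024
    for r in range(cp_size):
        print(r, rank_row_range(r, cp_size, ws_elems_per_rank))
    # 0 (0, 8192)
    # 1 (8192, 16384)
    # 2 (16384, 24576)
    # 3 (24576, 32768)
```

With a real contiguous workspace tensor ws, the same offset should equal ws[r].data_ptr() - ws.data_ptr().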

After allocation, call decode_cp_a2a_init_workspace() followed by a cross-rank barrier before the first decode_cp_a2a_alltoall() call.

Parameters:
  • mapping (Mapping) – Mapping object for MNNVL allocation. Carries cp_size and cp_rank. The communicator is split using mapping.pp_rank, mapping.cp_rank, and mapping.tp_rank.

  • mnnvl_config (MnnvlConfig, optional) – Configuration for the MNNVL communication backend. Required when using MNNVL with torch.distributed (pass MnnvlConfig(comm_backend=TorchDistBackend(group))).
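The mapping parameter is documented only as splitting the communicator using pp_rank, cp_rank, and tp_rank; the exact scheme is not stated. The sketch below shows one plausible split and is an assumption, not FlashInfer's actual logic. Ranks that share (pp_rank, tp_rank) form one group, ordered by cp_rank, so each group has cp_size members, which matches the workspace's leading dimension. The world-rank layout in coords_for and every helper name here are hypothetical.

```python
from __future__ import annotations

from collections import defaultdict
from typing import NamedTuple


class RankCoords(NamedTuple):
    pp_rank: int
    cp_rank: int
    tp_rank: int


def coords_for(world_rank: int, cp_size: int, tp_size: int) -> RankCoords:
    """Hypothetical layout: world_rank = (pp * cp_size + cp) * tp_size + tp."""
    tp = world_rank % tp_size
    cp = (world_rank // tp_size) % cp_size
    pp = world_rank // (tp_size * cp_size)
    return RankCoords(pp, cp, tp)


def split_cp_groups(coords_by_rank: dict[int, RankCoords]) -> dict[tuple[int, int], list[int]]:
    """Assumed split: color = (pp_rank, tp_rank), key = cp_rank."""
    groups: dict[tuple[int, int], list[tuple[int, int]]] = defaultdict(list)
    for world_rank, c in coords_by_rank.items():
        # Ranks that differ only in cp_rank share a color (one CP group).
        groups[(c.pp_rank, c.tp_rank)].append((c.cp_rank, world_rank))
    # Sorting by cp_rank orders each group so position i holds CP rank i.
    return {color: [w for _, w in sorted(members)] for color, members in groups.items()}


if __name__ == "__main__":
    pp_size, cp_size, tp_size = 2, 2, 2
    world = pp_size * cp_size * tp_size
    coords = {r: coords_for(r, cp_size, tp_size) for r in range(world)}
    for color, members in sorted(split_cp_groups(coords).items()):
        print(color, members)
    # (0, 0) [0, 2]
    # (0, 1) [1, 3]
    # (1, 0) [4, 6]
    # (1, 1) [5, 7]
```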

Returns:

torch.int64 tensor of shape [cp_size, ws_elems_per_rank].

Return type:

torch.Tensor