flashinfer.comm.decode_cp_a2a_allocate_mnnvl_workspace
- flashinfer.comm.decode_cp_a2a_allocate_mnnvl_workspace(mapping: Mapping, *, mnnvl_config: MnnvlConfig | None = None) → Tensor
Allocate an MNNVL-backed workspace of shape [cp_size, ws_elems_per_rank].
The DCP A2A kernel requires a single unified virtual address (VA) range spanning all CP ranks (see the module docstring), so workspace allocation must go through MNNVL fabric memory. This function is the only supported allocator.
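As a rough mental model of the documented [cp_size, ws_elems_per_rank] shape, the CPU-only sketch below computes where each CP rank's row sits inside the one shared allocation, and shows that any rank can address any row of a flat buffer by rank and element index alone. It assumes the returned tensor is contiguous and row-major (the PyTorch default for a freshly allocated tensor). The helper names are hypothetical, and the sketch says nothing about how the kernel actually uses each row.

```python
# Hypothetical CPU-only model of the [cp_size, ws_elems_per_rank] int64 workspace.
# Assumption: the tensor is contiguous and row-major, so row r belongs to CP rank r
# and all rows live in one address range (the "unified VA" the kernel relies on).

INT64_BYTES = 8  # element size of torch.int64


def row_offset_bytes(cp_rank: int, ws_elems_per_rank: int) -> int:
    """Byte offset of cp_rank's row from the start of the shared allocation."""
    return cp_rank * ws_elems_per_rank * INT64_BYTES


def element_offset_bytes(cp_rank: int, elem: int, ws_elems_per_rank: int) -> int:
    """Byte offset of workspace[cp_rank, elem] under a row-major layout."""
    if not 0 <= elem < ws_elems_per_rank:
        raise IndexError("elem is outside this rank's row")
    return row_offset_bytes(cp_rank, ws_elems_per_rank) + elem * INT64_BYTES


def total_bytes(cp_size: int, ws_elems_per_rank: int) -> int:
    """Size of the whole workspace: one row of int64 elements per CP rank."""
    return cp_size * ws_elems_per_rank * INT64_BYTES


def peer_write(flat: list, dst_rank: int, elem: int, value: int,
               ws_elems_per_rank: int) -> None:
    """Write into another rank's row of a flat buffer.

    Because every row shares one address range, a writer only needs the
    destination rank and element index; no separate per-peer buffer is looked up.
    """
    index = element_offset_bytes(dst_rank, elem, ws_elems_per_rank) // INT64_BYTES
    flat[index] = value


if __name__ == "__main__":
    cp_size, ws_elems_per_rank = 4, 1024
    flat = [0] * (cp_size * ws_elems_per_rank)
    peer_write(flat, dst_rank=2, elem=7, value=42, ws_elems_per_rank=ws_elems_per_rank)
    print(total_bytes(cp_size, ws_elems_per_rank))  # 32768
    print(flat[2 * ws_elems_per_rank + 7])          # 42
```

With cp_size = 4 and ws_elems_per_rank = 1024, the whole workspace is 32 KiB, and rank 2's row starts 16 KiB into it.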
After allocation, call decode_cp_a2a_init_workspace() followed by a cross-rank barrier before the first decode_cp_a2a_alltoall() call.
- Parameters:
  - mapping (Mapping) – Mapping object for MNNVL allocation. Carries cp_size and cp_rank. The communicator is split using mapping.pp_rank, mapping.cp_rank, and mapping.tp_rank.
  - mnnvl_config (MnnvlConfig, optional) – Configuration for the MNNVL communication backend. Required when using MNNVL with torch.distributed (pass MnnvlConfig(comm_backend=TorchDistBackend(group))).
- Returns:
  torch.int64 tensor of shape [cp_size, ws_elems_per_rank].
- Return type:
  torch.Tensor
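The setup order this entry requires (allocate, then decode_cp_a2a_init_workspace(), then a cross-rank barrier, and only then the first decode_cp_a2a_alltoall()) can be captured in a small helper. This is a sketch under stated assumptions: prepare_dcp_workspace is a hypothetical name, and the three steps are caller-supplied callables, because the argument lists of decode_cp_a2a_init_workspace() and decode_cp_a2a_alltoall() are not given in this entry. Injecting the steps also lets the ordering be checked with fakes, without MNNVL hardware.

```python
from typing import Callable, TypeVar

W = TypeVar("W")


def prepare_dcp_workspace(
    allocate: Callable[[], W],
    init_workspace: Callable[[W], None],
    barrier: Callable[[], None],
) -> W:
    """Hypothetical helper: run the documented setup sequence, return the workspace.

    allocate       -- closure over decode_cp_a2a_allocate_mnnvl_workspace(mapping, ...)
    init_workspace -- caller's wrapper around decode_cp_a2a_init_workspace()
                      (its real arguments are an assumption left to the caller)
    barrier        -- cross-rank barrier over the CP ranks
    """
    workspace = allocate()     # 1. MNNVL-backed [cp_size, ws_elems_per_rank] tensor
    init_workspace(workspace)  # 2. required initialization of the workspace
    barrier()                  # 3. wait until every CP rank has finished step 2
    return workspace           # 4. only now may decode_cp_a2a_alltoall() be called
```

On a real CP group, allocate could be functools.partial(decode_cp_a2a_allocate_mnnvl_workspace, mapping, mnnvl_config=cfg) and barrier could be lambda: torch.distributed.barrier(group=group); init_workspace must wrap decode_cp_a2a_init_workspace() with whatever arguments its own documentation specifies.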