flashinfer.comm.mnnvl.McastGPUBuffer

class flashinfer.comm.mnnvl.McastGPUBuffer(buf_size: int, group_size: int, group_rank: int, device: device, mn_nvlink: bool = True)

Wrapper class for McastDeviceMemory to facilitate PyTorch tensor creation. It manages a buffer accessible via unicast or multicast for multi-node communication.

Python port of McastGPUBuffer from TensorRT-LLM.

__init__(buf_size: int, group_size: int, group_rank: int, device: device, mn_nvlink: bool = True)

Constructor for McastGPUBuffer.

Parameters:
  • buf_size – The total size of the buffer in bytes

  • group_size – The number of ranks in the communication group

  • group_rank – The rank of the local process within the group

  • device – The CUDA device for buffer allocation

  • mn_nvlink – Whether multi-node NVLink is used (defaults to True)
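
Since buf_size is given in bytes, callers usually derive it from a tensor shape and element size. The following is a minimal sketch of that arithmetic only. The helper required_buf_size, the BYTES_PER_ELEMENT table, and the example shape are hypothetical and not part of flashinfer, and the sketch assumes the buffer needs to hold exactly one tensor; any alignment, padding, or extra space the library may need is not covered here. The constructor call is left as a comment because it needs a multi-GPU environment, and its group_size, group_rank, and device values are placeholders.

```python
import math

# Bytes per element for a few common dtypes (illustrative table, not part of flashinfer).
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def required_buf_size(shape, dtype_name):
    """Return the number of bytes needed to hold one tensor of `shape`."""
    return math.prod(shape) * BYTES_PER_ELEMENT[dtype_name]


# Hypothetical workload: 128 tokens with hidden size 4096 in bfloat16.
buf_size = required_buf_size((128, 4096), "bfloat16")
print(buf_size)  # 1048576

# On a suitably configured multi-GPU system, the value could then be passed
# to the constructor documented above (argument values are placeholders):
# import torch
# from flashinfer.comm.mnnvl import McastGPUBuffer
# buf = McastGPUBuffer(buf_size, group_size=4, group_rank=0,
#                      device=torch.device("cuda:0"))
```

Every rank in the group would be expected to pass the same buf_size and group_size, with group_rank differing per process.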

Methods

__init__(buf_size, group_size, group_rank, ...)

Constructor for McastGPUBuffer.

get_buffer_ptrs_dev()

Returns the device array of buffer pointers.

get_mc_buffer(sizes, dtype[, storage_offset])

Returns a PyTorch tensor view of the multicast buffer portion.

get_multicast_ptr()

Returns the raw multicast pointer.

lamport_initialize(rank, dtype)