flashinfer.comm.mnnvl.McastGPUBuffer
- class flashinfer.comm.mnnvl.McastGPUBuffer(buf_size: int, group_size: int, group_rank: int, device: device, mn_nvlink: bool = True)
Wrapper class for McastDeviceMemory to facilitate PyTorch tensor creation. It manages a buffer accessible via unicast or multicast for multi-node communication.
Python port of McastGPUBuffer from TensorRT-LLM.
- __init__(buf_size: int, group_size: int, group_rank: int, device: device, mn_nvlink: bool = True)
Constructor for McastGPUBuffer.
- Parameters:
buf_size – The total size of the buffer in bytes
group_size – The number of ranks in the communication group
group_rank – The rank of the local process within the group
device – The CUDA device for buffer allocation
mn_nvlink – Flag indicating if multi-node NVLink is used
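The following is a minimal sketch of how the constructor arguments might be derived, not code taken from FlashInfer. The helper `mcast_buffer_kwargs`, the `DTYPE_BYTES` table, and the contiguous rank-to-group mapping are hypothetical and exist only for illustration; only the parameter names come from the signature above. The constructor call itself is left as a comment because it needs a CUDA device and a multi-rank setup.

```python
import math

# Hypothetical lookup of element sizes in bytes (not part of flashinfer).
DTYPE_BYTES = {"float32": 4, "float16": 2, "bfloat16": 2}


def mcast_buffer_kwargs(shape, dtype_name, world_rank, group_size):
    """Build keyword arguments named after the documented constructor parameters."""
    # buf_size is documented as the total buffer size in bytes.
    buf_size = math.prod(shape) * DTYPE_BYTES[dtype_name]
    # Assumption: ranks are assigned to groups contiguously, so the
    # position inside the group is the remainder of the global rank.
    group_rank = world_rank % group_size
    return {
        "buf_size": buf_size,
        "group_size": group_size,
        "group_rank": group_rank,
        "mn_nvlink": True,
    }


kwargs = mcast_buffer_kwargs((1024, 4096), "bfloat16", world_rank=5, group_size=4)
print(kwargs)

# With flashinfer installed and a suitable GPU setup, the call would look like:
#   import torch
#   from flashinfer.comm.mnnvl import McastGPUBuffer
#   buf = McastGPUBuffer(device=torch.device("cuda", 0), **kwargs)
```

Computing `buf_size` from the shape and element size, rather than hard-coding it, keeps the byte count consistent with the tensors later requested from the buffer.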
Methods
__init__(buf_size, group_size, group_rank, ...) – Constructor for McastGPUBuffer.
get_buffer_ptrs_dev() – Get the device array of buffer pointers.
get_mc_buffer(sizes, dtype[, storage_offset]) – Returns a PyTorch tensor view of the multicast buffer portion.
get_multicast_ptr() – Get the raw multicast pointer.
lamport_initialize(rank, dtype)
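Because get_mc_buffer returns a view into a buffer of fixed size, it can help to check that a requested view fits before asking for it. The sketch below is a hypothetical pre-check, not part of FlashInfer. It assumes that `storage_offset` is counted in elements (as in PyTorch's own storage offsets) and that the view is contiguous; neither is confirmed by this page.

```python
import math


def view_fits(buf_size, sizes, elem_bytes, storage_offset=0):
    """Return True if a contiguous view of `sizes` fits in `buf_size` bytes.

    Assumption: `storage_offset` is measured in elements, not bytes.
    """
    needed_bytes = (storage_offset + math.prod(sizes)) * elem_bytes
    return needed_bytes <= buf_size


buf_size = 1024 * 4096 * 2  # bytes, e.g. a 1024 x 4096 buffer of 2-byte elements

fits_exactly = view_fits(buf_size, (1024, 4096), elem_bytes=2)
overflows = view_fits(buf_size, (1024, 4096), elem_bytes=2, storage_offset=1)
print(fits_exactly, overflows)
```

A check like this fails early in Python instead of relying on the library to reject an out-of-range request.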