flashinfer.comm.mnnvl.McastGPUBuffer

class flashinfer.comm.mnnvl.McastGPUBuffer(buf_size: int, group_size: int, group_rank: int, device: device, comm_backend_for_handle_transfer: CommBackend | None = None)

Wrapper class for SymmDeviceMemory to facilitate PyTorch tensor creation. It manages a buffer accessible via unicast or multicast for multi-node communication.

Python port of McastGPUBuffer from TensorRT-LLM.

__init__(buf_size: int, group_size: int, group_rank: int, device: device, comm_backend_for_handle_transfer: CommBackend | None = None)

Constructor for McastGPUBuffer.

Parameters:
  • buf_size – The requested size of the buffer in bytes. The actual usable size may differ due to alignment requirements.

  • group_size – The number of ranks in the communication group.

  • group_rank – The rank of the local process within the group.

  • device – The CUDA device for buffer allocation.

  • comm_backend_for_handle_transfer – Communication backend used to transfer memory handles between ranks.
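As the buf_size note says, the usable size may differ from the requested size because of alignment. A minimal sketch of rounding a requested size up to an allocation boundary (the 512-byte granularity here is a hypothetical placeholder, not flashinfer's actual value):

```python
def round_up(size: int, alignment: int) -> int:
    """Round `size` up to the next multiple of `alignment`."""
    return ((size + alignment - 1) // alignment) * alignment

# Hypothetical granularity; the real allocator's alignment may differ.
ALIGNMENT = 512
print(round_up(1000, ALIGNMENT))  # -> 1024
```

Requesting an already-aligned size is a no-op under this scheme, so over-requesting by one granularity unit is never necessary.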

Methods

__init__(buf_size, group_size, group_rank, ...)

Constructor for McastGPUBuffer.

get_buffer_ptrs_dev()

Get the device array of buffer pointers.

get_multicast_buffer(sizes, dtype[, ...])

Returns a PyTorch tensor view of the multicast buffer portion.

get_multicast_ptr()

Get the raw multicast pointer.

get_unicast_buffer(sizes, dtype[, ...])

Returns a PyTorch tensor view of the unicast buffer portion.

get_unicast_ptr(rank)

Get the raw unicast pointer for a given rank.

lamport_initialize(rank, dtype)

Initialize the buffer contents for Lamport-style synchronization.
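get_multicast_buffer and get_unicast_buffer return PyTorch tensor views over the shared allocation rather than copies, so writes through one view are visible through the others. The aliasing idea can be sketched with stdlib memoryview alone (illustrative only, not flashinfer code):

```python
# A flat byte buffer standing in for the symmetric device allocation.
buf = bytearray(64)

# Reinterpret the raw bytes as a float32 view and write through it ...
view_a = memoryview(buf).cast("f")
view_a[0] = 1.5
view_a[1] = -2.0

# ... then observe the same values through a second view of the same
# memory, analogous to unicast/multicast views aliasing one buffer.
view_b = memoryview(buf).cast("f")
print(view_b[0], view_b[1])  # -> 1.5 -2.0
```

Because the views alias one allocation, no data movement happens when switching between them; only the access path differs.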