flashinfer.comm.MNNVLAllReduceFusionWorkspace
- class flashinfer.comm.MNNVLAllReduceFusionWorkspace(mapping: Mapping, max_num_tokens: int | None = None, hidden_dim: int | None = None, dtype: dtype | None = None, buffer_size_in_bytes: int | None = None, comm_backend: CommBackend | None = None)
- __init__(mapping: Mapping, max_num_tokens: int | None = None, hidden_dim: int | None = None, dtype: dtype | None = None, buffer_size_in_bytes: int | None = None, comm_backend: CommBackend | None = None)
Initialize the MNNVL Allreduce Fusion Workspace. The workspace is allocated and initialized based on the provided problem size. If max_num_tokens is larger than the one-shot threshold, the workspace is sized to the larger of the one-shot size required at the threshold and the required two-shot size. Note that the workspace is not bound to the given problem size: it can be reused for different problem sizes without reinitialization, as long as the allocated size is sufficient.
If buffer_size_in_bytes is provided, the workspace is created according to that size instead. Use the utility function get_required_buffer_size_bytes to calculate the required size. The actual allocation size may be larger due to alignment requirements. This covers advanced use cases; for example, a user may want to enforce the one-shot strategy and bypass the heuristics.
Either max_num_tokens or buffer_size_in_bytes must be provided.
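The sizing rules above can be summarized as a small decision function. This is a minimal sketch, not flashinfer's implementation: `resolve_buffer_size`, `toy_one_shot_size`, `toy_two_shot_size`, and the threshold value are hypothetical, and the toy size formulas are made up purely to exercise the logic. The branch for max_num_tokens at or below the threshold is an assumption, since the text only describes the larger case.

```python
from typing import Callable, Optional


def resolve_buffer_size(
    max_num_tokens: Optional[int],
    buffer_size_in_bytes: Optional[int],
    one_shot_size: Callable[[int], int],
    two_shot_size: Callable[[int], int],
    one_shot_threshold: int,
) -> int:
    """Hypothetical sketch: pick the per-buffer size to request (before alignment)."""
    # Advanced path: an explicit size is used as-is and bypasses the heuristic.
    if buffer_size_in_bytes is not None:
        return buffer_size_in_bytes
    if max_num_tokens is None:
        raise ValueError("Either max_num_tokens or buffer_size_in_bytes must be provided.")
    if max_num_tokens > one_shot_threshold:
        # Size for whichever strategy needs more memory: one-shot at the
        # threshold, or two-shot at max_num_tokens.
        return max(one_shot_size(one_shot_threshold), two_shot_size(max_num_tokens))
    # Assumption: at or below the threshold, only the one-shot size matters.
    return one_shot_size(max_num_tokens)


# Toy size models (made up for illustration, NOT flashinfer's formulas).
def toy_one_shot_size(num_tokens: int) -> int:
    return num_tokens * 8192


def toy_two_shot_size(num_tokens: int) -> int:
    return num_tokens * 2048


print(resolve_buffer_size(4096, None, toy_one_shot_size, toy_two_shot_size, 128))  # → 8388608
```

Assuming both size functions grow with the token count, taking the maximum is what lets a single allocation serve either strategy for any token count up to max_num_tokens.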
comm_backend is used to create the workspace and for synchronization. If it is not provided, MPIBackend is used, which synchronizes over COMM_WORLD.
- Parameters:
mapping – Mapping configuration containing rank information.
max_num_tokens – The maximum number of tokens in the input tensor.
hidden_dim – The hidden dimension of the tensors to be reduced.
dtype – The data type of the tensors to be reduced.
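buffer_size_in_bytes – The requested size in bytes for each Lamport buffer. The actual allocation size may be larger due to alignment requirements; the total usable size is NUM_LAMPORT_BUFFERS * actual_buffer_size_per_lamport_buffer.
comm_backend – The communication backend used to create the workspace and for synchronization. If not provided, MPIBackend over COMM_WORLD is used.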
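The relationship between the requested size (buffer_size_in_bytes) and the total usable size is plain arithmetic: round each buffer up to the alignment granularity, then multiply by the number of Lamport buffers. In this sketch, `align_up`, `usable_size_bytes`, and the example alignment (2 MiB) and buffer count (3) are illustrative assumptions; the real alignment and NUM_LAMPORT_BUFFERS value should be read from the library, not taken from here.

```python
def align_up(n: int, alignment: int) -> int:
    """Round n up to the next multiple of alignment (alignment > 0)."""
    return -(-n // alignment) * alignment


def usable_size_bytes(requested_per_buffer: int, alignment: int, num_lamport_buffers: int) -> int:
    """Hypothetical: total usable bytes = number of buffers * aligned per-buffer size."""
    actual_per_buffer = align_up(requested_per_buffer, alignment)
    return num_lamport_buffers * actual_per_buffer


# Illustrative values only (assumed, not flashinfer's actual constants).
ALIGNMENT = 2 * 1024 * 1024  # 2 MiB
NUM_BUFFERS = 3

print(align_up(5_000_000, ALIGNMENT))                        # → 6291456
print(usable_size_bytes(5_000_000, ALIGNMENT, NUM_BUFFERS))  # → 18874368
```

Under these assumed values, a 5,000,000-byte request becomes three 2 MiB-aligned buffers of 6,291,456 bytes each, so the allocation is noticeably larger than requested.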
Methods
__init__(mapping[, max_num_tokens, ...]) – Initialize the MNNVL Allreduce Fusion Workspace.
destroy() – Destroy workspace and free resources.
get_required_buffer_size_bytes(tp_size, ...) – Calculate the required buffer size for a given problem size.
is_buffer_size_sufficient(tp_size, ...[, ...]) – Check whether the buffer size is sufficient for a given problem size.
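Because a workspace is not bound to one problem size, a caller can keep a single instance and reallocate only when a new problem no longer fits. The sketch below shows that reuse-or-recreate pattern against a hypothetical `StandInWorkspace`; its `is_sufficient(required_bytes)` method stands in for is_buffer_size_sufficient, whose full signature is not shown here, and `get_workspace` is likewise a made-up helper. Only destroy() mirrors a method listed above.

```python
from typing import Callable, Optional, Protocol


class Workspace(Protocol):
    def is_sufficient(self, required_bytes: int) -> bool: ...
    def destroy(self) -> None: ...


class StandInWorkspace:
    """Hypothetical stand-in: tracks a capacity instead of real GPU memory."""

    def __init__(self, capacity_bytes: int) -> None:
        self.capacity_bytes = capacity_bytes
        self.destroyed = False

    def is_sufficient(self, required_bytes: int) -> bool:
        return required_bytes <= self.capacity_bytes

    def destroy(self) -> None:
        # A real workspace would free its buffers here.
        self.destroyed = True


def get_workspace(
    current: Optional[Workspace],
    required_bytes: int,
    create: Callable[[int], Workspace],
) -> Workspace:
    # Reuse the existing allocation whenever the new problem still fits.
    if current is not None and current.is_sufficient(required_bytes):
        return current
    # Otherwise release the old resources before allocating a larger workspace.
    if current is not None:
        current.destroy()
    return create(required_bytes)


ws = get_workspace(None, 1 << 20, StandInWorkspace)
same = get_workspace(ws, 1 << 19, StandInWorkspace)    # fits: reused
bigger = get_workspace(ws, 1 << 21, StandInWorkspace)  # too large: recreated
print(same is ws, ws.destroyed, bigger.capacity_bytes)  # → True True 2097152
```

Destroying the old workspace before creating the new one avoids holding two allocations at the same time.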
Attributes
NUM_LAMPORT_BUFFERS
backend – Return backend name.
world_size
rank