flashinfer.comm.mixed_comm.MixedCommHandler
- class flashinfer.comm.mixed_comm.MixedCommHandler(world_rank: int, world_size: int, local_rank: int, local_size: int, inter_rank: int, inter_size: int, local_tp_size: int | None, local_dp_size: int | None, inter_tp_size: int | None, inter_dp_size: int | None, dtype: dtype, device: device, grid_size: int | None = None, max_block_size: int | None = None, min_block_size: int = 256, min_num_steps: int = 4, ib_enable_ibgda: bool = True, should_init_nvshmem: bool = True, use_autotune: bool = True)
Handler for fused communication kernels that combine all-reduce with all-gather, and reduce-scatter with all-reduce. The fused kernels use CUDA virtual memory for intra-node communication and nvshmem for inter-node communication. Currently, only the float16 and bfloat16 data types are supported. Note: an active torch.distributed process group should be initialized (for example, with torch.distributed.init_process_group) before creating an instance of this class.
- Parameters:
world_rank (int) – The world rank of the current process (local_rank + inter_rank * local_size).
world_size (int) – The total number of processes in the distributed group (inter_size * local_size).
local_rank (int) – The local rank of the current process (local_tp_rank + local_dp_rank * local_tp_size).
local_size (int) – The total number of processes in the current node (local_dp_size * local_tp_size).
inter_rank (int) – The index of the current node (inter_tp_rank + inter_dp_rank * inter_tp_size).
inter_size (int) – The total number of nodes in the distributed group (inter_dp_size * inter_tp_size).
local_tp_size (int | None) – TP size in the intra-node group. If None, the default value is used.
local_dp_size (int | None) – DP size in the intra-node group. If None, the default value is used.
inter_tp_size (int | None) – TP size in the inter-node group. If None, the default value is used.
inter_dp_size (int | None) – DP size in the inter-node group. If None, the default value is used.
dtype (torch.dtype) – The data type of the communicated tensors; must be torch.float16 or torch.bfloat16.
device (torch.device) – The device on which the tensors are located.
grid_size (int | None, optional) – The number of CTAs launched per GPU. If None, the number of SMs on the device is used.
max_block_size (int | None, optional) – The upper limit on the block size. If None, no limit is applied.
min_block_size (int, optional) – The minimum block size when the operation is split into multiple steps.
min_num_steps (int, optional) – The minimum number of steps when the largest possible block size is chosen.
ib_enable_ibgda (bool, optional) – Whether to enable IBGDA (InfiniBand GPUDirect Async) for inter-node transfers.
should_init_nvshmem (bool, optional) – Whether to initialize nvshmem.
use_autotune (bool, optional) – Whether to run autotune during initialization.
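The rank formulas in the parameter list are nested mixed-radix decompositions, so every index can be recovered from world_rank with divmod. The sketch below is illustrative only and not part of flashinfer: the names RankLayout and decompose_rank are hypothetical, and it assumes every node hosts the same number of GPUs (local_size) and that each TP size evenly divides its group size.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RankLayout:
    """Hypothetical container for the indices described in the parameter list."""
    world_rank: int
    world_size: int
    local_rank: int
    local_size: int
    inter_rank: int
    inter_size: int
    local_tp_rank: int
    local_dp_rank: int
    inter_tp_rank: int
    inter_dp_rank: int


def decompose_rank(world_rank: int, local_size: int, inter_size: int,
                   local_tp_size: int, inter_tp_size: int) -> RankLayout:
    """Split world_rank using the documented formulas (assumes uniform nodes)."""
    world_size = inter_size * local_size
    if not 0 <= world_rank < world_size:
        raise ValueError(f"world_rank {world_rank} out of range [0, {world_size})")
    # local_size = local_dp_size * local_tp_size, inter_size = inter_dp_size * inter_tp_size
    if local_size % local_tp_size or inter_size % inter_tp_size:
        raise ValueError("TP sizes must evenly divide the corresponding group sizes")
    # world_rank = local_rank + inter_rank * local_size
    inter_rank, local_rank = divmod(world_rank, local_size)
    # local_rank = local_tp_rank + local_dp_rank * local_tp_size
    local_dp_rank, local_tp_rank = divmod(local_rank, local_tp_size)
    # inter_rank = inter_tp_rank + inter_dp_rank * inter_tp_size
    inter_dp_rank, inter_tp_rank = divmod(inter_rank, inter_tp_size)
    return RankLayout(world_rank, world_size, local_rank, local_size,
                      inter_rank, inter_size, local_tp_rank, local_dp_rank,
                      inter_tp_rank, inter_dp_rank)
```

For example, with 2 nodes of 8 GPUs, local_tp_size=4 and inter_tp_size=1, world rank 13 maps to inter_rank=1 and local_rank=5, which in turn give local_dp_rank=1, local_tp_rank=1, inter_dp_rank=1 and inter_tp_rank=0.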
- Raises:
RuntimeError – If nvshmem fails to initialize.
- __init__(world_rank: int, world_size: int, local_rank: int, local_size: int, inter_rank: int, inter_size: int, local_tp_size: int | None, local_dp_size: int | None, inter_tp_size: int | None, inter_dp_size: int | None, dtype: dtype, device: device, grid_size: int | None = None, max_block_size: int | None = None, min_block_size: int = 256, min_num_steps: int = 4, ib_enable_ibgda: bool = True, should_init_nvshmem: bool = True, use_autotune: bool = True)
Initialize the handler, set up virtual memory, nvshmem, and run autotune.
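Autotuning here means profiling the valid (op, mode) combinations once and storing the results in a map that select_autotune_mode later consults. The plain-Python sketch below illustrates that general technique; it is not flashinfer's implementation. The names build_autotune_map, select_mode and wall_clock_bench are hypothetical, the map is assumed to be keyed by (op, mode, size), and an input is assumed to be matched to the smallest profiled size that is at least as large as it.

```python
import time
from typing import Callable, Dict, Hashable, Iterable, Tuple

# Hypothetical layout: (op, mode, num_elements) -> mean latency in seconds.
AutotuneMap = Dict[Tuple[Hashable, Hashable, int], float]
BenchFn = Callable[[Hashable, Hashable, int], float]


def build_autotune_map(ops: Iterable[Hashable], modes: Iterable[Hashable],
                       sizes: Iterable[int], bench: BenchFn) -> AutotuneMap:
    """Profile every (op, mode) pair at every probe size."""
    modes, sizes = list(modes), list(sizes)  # iterated once per op below
    table: AutotuneMap = {}
    for op in ops:
        for mode in modes:
            for size in sizes:
                table[(op, mode, size)] = bench(op, mode, size)
    return table


def select_mode(table: AutotuneMap, op: Hashable, num_elements: int) -> Hashable:
    """Return the fastest mode for op at the smallest profiled size >= num_elements."""
    sizes = sorted({size for (o, _, size) in table if o == op})
    if not sizes:
        raise KeyError(f"no profiling data for op {op!r}")
    # Inputs larger than every probe fall back to the largest profiled size.
    probe = next((s for s in sizes if s >= num_elements), sizes[-1])
    timings = {mode: t for (o, mode, s), t in table.items() if o == op and s == probe}
    return min(timings, key=timings.__getitem__)


def wall_clock_bench(run: Callable[[Hashable, Hashable, int], None],
                     repeats: int = 5) -> BenchFn:
    """Turn a launcher run(op, mode, size) into a bench function using perf_counter.

    GPU kernel launches are asynchronous, so run() must synchronize the device
    (e.g. torch.cuda.synchronize()) before returning, or only launch overhead
    is measured.
    """
    def bench(op: Hashable, mode: Hashable, size: int) -> float:
        start = time.perf_counter()
        for _ in range(repeats):
            run(op, mode, size)
        return (time.perf_counter() - start) / repeats
    return bench
```

Profiling once up front keeps per-call selection to a dictionary lookup, at the cost of a slower constructor.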
Methods
__init__(world_rank, world_size, local_rank, ...) – Initialize the handler, set up virtual memory, nvshmem, and run autotune.
get_valid_mode_list() – Return the list of valid execution modes for the current parallel topology.
get_valid_op_list() – Return the list of valid communication operations for the current parallel topology.
init_nvshmem() – Initialize nvshmem for inter-node communication.
init_virtual_memory() – Initialize CUDA virtual memory for intra-node communication.
run_autotune() – Profile all valid (op, mode) combinations and populate the autotune map.
select_autotune_mode(op, x_in) – Select the best execution mode for the given operation and input tensor.
shutdown() – Tear down all communication resources.
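The class exposes shutdown() for teardown but does not document context-manager support. A minimal sketch, assuming only that shutdown() takes no arguments as listed above, wraps construction and teardown with contextlib so resources are released even if the body raises. The managed_handler name is hypothetical and not part of flashinfer.

```python
from contextlib import contextmanager
from typing import Callable, Iterator, TypeVar

H = TypeVar("H")


@contextmanager
def managed_handler(factory: Callable[[], H]) -> Iterator[H]:
    """Build a handler with factory() and always call its shutdown() on exit."""
    # If construction itself raises (e.g. RuntimeError from nvshmem setup),
    # no handler exists yet, so there is nothing to shut down.
    handler = factory()
    try:
        yield handler
    finally:
        handler.shutdown()  # runs on normal exit and when the body raises
```

In a real script the factory would be something like `lambda: MixedCommHandler(...)` with the constructor arguments above, and the body of the `with` statement could call methods such as `select_autotune_mode(op, x_in)`.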