flashinfer.comm.mixed_comm.MixedCommHandler

class flashinfer.comm.mixed_comm.MixedCommHandler(world_rank: int, world_size: int, local_rank: int, local_size: int, inter_rank: int, inter_size: int, local_tp_size: int | None, local_dp_size: int | None, inter_tp_size: int | None, inter_dp_size: int | None, dtype: dtype, device: device, grid_size: int | None = None, max_block_size: int | None = None, min_block_size: int = 256, min_num_steps: int = 4, ib_enable_ibgda: bool = True, should_init_nvshmem: bool = True, use_autotune: bool = True)

Implements the fused collective combinations all-reduce + all-gather and reduce-scatter + all-reduce. The fused kernels use CUDA virtual memory for intra-node communication and nvshmem for inter-node communication. Currently, only the float16 and bfloat16 data types are supported.

Note: Initialize an active torch.distributed process group before creating an instance of this class.

Parameters:
  • world_rank (int) – The world rank of the current process (local_rank + inter_rank * local_size).

  • world_size (int) – The total number of processes in the distributed group (inter_size * local_size).

  • local_rank (int) – The local rank of the current process (local_tp_rank + local_dp_rank * local_tp_size).

  • local_size (int) – The total number of processes in the current node (local_dp_size * local_tp_size).

  • inter_rank (int) – The index of the current node (inter_tp_rank + inter_dp_rank * inter_tp_size).

  • inter_size (int) – The total number of nodes in the distributed group (inter_dp_size * inter_tp_size).

  • local_tp_size (int | None) – Tensor-parallel (TP) size in the intra-node group. If None, the default value is used.

  • local_dp_size (int | None) – Data-parallel (DP) size in the intra-node group. If None, the default value is used.

  • inter_tp_size (int | None) – TP size in the inter-node group. If None, the default value is used.

  • inter_dp_size (int | None) – DP size in the inter-node group. If None, the default value is used.

  • dtype (torch.dtype) – The data type of the tensors to communicate. Only torch.float16 and torch.bfloat16 are currently supported.

  • device (torch.device) – The device on which the tensors are located.

  • grid_size (int | None, optional) – The number of CTAs per GPU. If None, the number of SMs is used.

  • max_block_size (int | None, optional) – The upper limit on the block size. If None, no limit is set.

  • min_block_size (int, optional) – The minimum block size when multiple steps are used.

  • min_num_steps (int, optional) – The minimum number of steps if the maximum possible value of block size is chosen.

  • ib_enable_ibgda (bool, optional) – Whether to enable IBGDA (InfiniBand GPUDirect Async).

  • should_init_nvshmem (bool, optional) – Whether to initialize nvshmem.

  • use_autotune (bool, optional) – Whether to run autotune during initialization.

Raises:

RuntimeError – If nvshmem fails to initialize.
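
The rank and size relations given in the parameter descriptions can be checked with plain integer arithmetic before the handler is constructed. The following is a minimal sketch, not part of flashinfer: the names RankLayout, decompose_world_rank, and sizes_are_consistent are hypothetical helpers introduced here for illustration, and the sketch assumes that all TP and DP sizes are given explicitly rather than as None.

```python
from typing import NamedTuple


class RankLayout(NamedTuple):
    """Hypothetical container for the ranks derived from world_rank."""
    local_rank: int
    inter_rank: int
    local_tp_rank: int
    local_dp_rank: int
    inter_tp_rank: int
    inter_dp_rank: int


def decompose_world_rank(world_rank: int, local_size: int,
                         local_tp_size: int, inter_tp_size: int) -> RankLayout:
    """Invert the documented rank formulas with integer division."""
    # world_rank = local_rank + inter_rank * local_size
    inter_rank, local_rank = divmod(world_rank, local_size)
    # local_rank = local_tp_rank + local_dp_rank * local_tp_size
    local_dp_rank, local_tp_rank = divmod(local_rank, local_tp_size)
    # inter_rank = inter_tp_rank + inter_dp_rank * inter_tp_size
    inter_dp_rank, inter_tp_rank = divmod(inter_rank, inter_tp_size)
    return RankLayout(local_rank, inter_rank, local_tp_rank,
                      local_dp_rank, inter_tp_rank, inter_dp_rank)


def sizes_are_consistent(world_size: int, local_size: int, inter_size: int,
                         local_tp_size: int, local_dp_size: int,
                         inter_tp_size: int, inter_dp_size: int) -> bool:
    """Check the documented size products against each other."""
    return (world_size == inter_size * local_size
            and local_size == local_dp_size * local_tp_size
            and inter_size == inter_dp_size * inter_tp_size)


# Example topology (an assumption for illustration): 2 nodes with 8 GPUs each,
# intra-node TP=4 and DP=2, inter-node TP=2 and DP=1.
layout = decompose_world_rank(world_rank=13, local_size=8,
                              local_tp_size=4, inter_tp_size=2)
```

For world_rank 13 in this example topology, the process is local rank 5 on node 1, which corresponds to local_tp_rank 1 and local_dp_rank 1.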

__init__(world_rank: int, world_size: int, local_rank: int, local_size: int, inter_rank: int, inter_size: int, local_tp_size: int | None, local_dp_size: int | None, inter_tp_size: int | None, inter_dp_size: int | None, dtype: dtype, device: device, grid_size: int | None = None, max_block_size: int | None = None, min_block_size: int = 256, min_num_steps: int = 4, ib_enable_ibgda: bool = True, should_init_nvshmem: bool = True, use_autotune: bool = True)

Initialize the handler, set up virtual memory and nvshmem, and run autotune.

Methods

__init__(world_rank, world_size, local_rank, ...)

Initialize the handler, set up virtual memory and nvshmem, and run autotune.

get_valid_mode_list()

Return the list of valid execution modes for the current parallel topology.

get_valid_op_list()

Return the list of valid communication operations for the current parallel topology.

init_nvshmem()

Initialize nvshmem for inter-node communication.

init_virtual_memory()

Initialize CUDA virtual memory for intra-node communication.

run_autotune()

Profile all valid (op, mode) combinations and populate the autotune map.

select_autotune_mode(op, x_in)

Select the best execution mode for the given operation and input tensor.

shutdown()

Tear down all communication resources.
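
Because the constructor acquires communication resources and shutdown() releases them, one way to pair the two calls is a small context manager. This is a minimal sketch under stated assumptions: handler_session is a hypothetical helper, not a flashinfer API, and it assumes that calling shutdown() once after use is the intended teardown. The handler class is passed in as an argument so that the helper itself has no dependency on a GPU or on flashinfer.

```python
from contextlib import contextmanager


@contextmanager
def handler_session(handler_cls, **init_kwargs):
    """Create a handler and guarantee that shutdown() runs on exit.

    handler_cls is expected to be MixedCommHandler or any class that
    exposes a shutdown() method; init_kwargs are forwarded unchanged
    to its constructor.
    """
    handler = handler_cls(**init_kwargs)
    try:
        yield handler
    finally:
        # Runs on normal exit and when the body raises.
        handler.shutdown()
```

In a real job, handler_cls would be flashinfer.comm.mixed_comm.MixedCommHandler and init_kwargs would carry the constructor arguments listed above, with the torch.distributed process group already initialized.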