flashinfer.comm.TRTLLMAllReduceFusionWorkspace

class flashinfer.comm.TRTLLMAllReduceFusionWorkspace(tp_size: int, tp_rank: int, max_token_num: int, hidden_dim: int, dtype: dtype = torch.float16, comm_backend: CommBackend | None = None)

TensorRT-LLM workspace for AllReduce fusion.

__init__(tp_size: int, tp_rank: int, max_token_num: int, hidden_dim: int, dtype: dtype = torch.float16, comm_backend: CommBackend | None = None)

Create TensorRT-LLM AllReduce fusion workspace.

Parameters:
  • tp_size – Tensor parallel size (world size)

  • tp_rank – Tensor parallel rank

  • max_token_num – Maximum number of tokens the workspace is sized for

  • hidden_dim – Hidden dimension size

  • dtype – Data type (defaults to torch.float16)

  • comm_backend – Communication backend (defaults to None)

  • **kwargs – Additional arguments for workspace creation
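The constructor arguments listed above can be assembled and sanity-checked before a workspace is created. The following is a minimal sketch, not the library's own code. `build_workspace_kwargs` and `create_workspace` are hypothetical helpers, and the validation rules (positive sizes, `0 <= tp_rank < tp_size`) are assumptions, not documented behaviour. Construction is only attempted when `flashinfer` is importable, and it presumably also needs a working multi-GPU tensor-parallel environment.

```python
def build_workspace_kwargs(tp_size, tp_rank, max_token_num, hidden_dim):
    """Hypothetical helper: validate and pack the documented constructor arguments.

    The checks below are assumptions about sensible values, not rules
    taken from the flashinfer documentation.
    """
    if tp_size < 1:
        raise ValueError("tp_size must be at least 1")
    if not 0 <= tp_rank < tp_size:
        raise ValueError("tp_rank must satisfy 0 <= tp_rank < tp_size")
    if max_token_num < 1 or hidden_dim < 1:
        raise ValueError("max_token_num and hidden_dim must be positive")
    return {
        "tp_size": tp_size,
        "tp_rank": tp_rank,
        "max_token_num": max_token_num,
        "hidden_dim": hidden_dim,
    }


def create_workspace(**kwargs):
    """Hypothetical helper: build the workspace, or return None without flashinfer."""
    try:
        from flashinfer.comm import TRTLLMAllReduceFusionWorkspace
    except ImportError:
        return None
    # dtype and comm_backend are omitted, so their documented defaults apply.
    return TRTLLMAllReduceFusionWorkspace(**kwargs)


# Example values are illustrative only.
kwargs = build_workspace_kwargs(tp_size=4, tp_rank=1, max_token_num=2048, hidden_dim=4096)
```

Each rank would then call `create_workspace(**kwargs)` with its own `tp_rank`. Passing `dtype` or `comm_backend` explicitly works the same way, as extra keyword arguments.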

Methods

__init__(tp_size, tp_rank, max_token_num, ...)

Create TensorRT-LLM AllReduce fusion workspace.

destroy()

Destroy workspace and free resources.

is_buffer_size_sufficient(tp_size, ...[, ...])

Attributes

backend

Return backend name.

world_size

rank