flashinfer.comm.TRTLLMAllReduceFusionWorkspace
- class flashinfer.comm.TRTLLMAllReduceFusionWorkspace(tp_size: int, tp_rank: int, max_token_num: int, hidden_dim: int, dtype: dtype = torch.float16, comm_backend: CommBackend | None = None)
TensorRT-LLM workspace for AllReduce fusion.
- __init__(tp_size: int, tp_rank: int, max_token_num: int, hidden_dim: int, dtype: dtype = torch.float16, comm_backend: CommBackend | None = None)
Create TensorRT-LLM AllReduce fusion workspace.
- Parameters:
tp_size – Tensor parallel size (world size)
tp_rank – Tensor parallel rank
max_token_num – Maximum number of tokens
hidden_dim – Hidden dimension size
dtype – Data type (default: torch.float16)
comm_backend – Communication backend (default: None)
**kwargs – Additional arguments for workspace creation
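As a minimal sketch of how the parameters above might be assembled for one rank, the helper below collects them into a keyword dictionary. The name build_workspace_kwargs and its sanity checks are hypothetical and not part of FlashInfer; only the keyword names come from the signature documented here.

```python
from typing import Any, Dict, Optional


def build_workspace_kwargs(
    tp_size: int,
    tp_rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: Optional[Any] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: collect constructor keyword arguments for one rank."""
    # These checks are assumptions of this sketch, not documented library behavior.
    if tp_size < 1:
        raise ValueError("tp_size must be at least 1")
    if not 0 <= tp_rank < tp_size:
        raise ValueError("tp_rank must satisfy 0 <= tp_rank < tp_size")
    if max_token_num < 1 or hidden_dim < 1:
        raise ValueError("max_token_num and hidden_dim must be positive")

    kwargs: Dict[str, Any] = {
        "tp_size": tp_size,
        "tp_rank": tp_rank,
        "max_token_num": max_token_num,
        "hidden_dim": hidden_dim,
    }
    # Leaving dtype out keeps the documented default (torch.float16) in effect.
    if dtype is not None:
        kwargs["dtype"] = dtype
    return kwargs
```

On a machine with FlashInfer installed and the tensor-parallel ranks already set up, the result could presumably be passed as TRTLLMAllReduceFusionWorkspace(**kwargs); that call is not exercised here.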
Methods
__init__(tp_size, tp_rank, max_token_num, ...) – Create TensorRT-LLM AllReduce fusion workspace.
destroy() – Destroy workspace and free resources.
is_buffer_size_sufficient(tp_size, ...[, ...])
Attributes
backend – Return backend name.
world_size
rank
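Because the class exposes destroy() to free its resources, one possible pattern is to wrap construction in a context manager so cleanup runs even if an error occurs. This is a sketch only: managed_workspace is a hypothetical wrapper, not a FlashInfer API, and factory stands in for the workspace class.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator


@contextmanager
def managed_workspace(factory: Callable[..., Any], **kwargs: Any) -> Iterator[Any]:
    """Hypothetical wrapper: create a workspace and always call destroy()."""
    workspace = factory(**kwargs)
    try:
        yield workspace
    finally:
        # destroy() is the documented method that frees the workspace resources.
        workspace.destroy()


# Intended usage (assumes FlashInfer is installed and ranks are initialized):
#
#   from flashinfer.comm import TRTLLMAllReduceFusionWorkspace
#   with managed_workspace(TRTLLMAllReduceFusionWorkspace, tp_size=2, tp_rank=0,
#                          max_token_num=2048, hidden_dim=4096) as ws:
#       ...  # run fused all-reduce work with ws
```

The wrapper takes the class as an argument rather than importing it, so the cleanup logic can be checked with a stand-in object on a machine without GPUs.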