flashinfer.comm.moe_a2a_wrap_payload_tensor_in_workspace
- flashinfer.comm.moe_a2a_wrap_payload_tensor_in_workspace(workspace: Tensor, leading_shape: list[int], slice_start: int, slice_end: int, dtype: dtype) → Tensor
Wrap a slice of the shared workspace as a typed tensor view.
- Parameters:
workspace (torch.Tensor) – [ep_size, size_per_rank] (or [size_per_rank]) workspace tensor.
leading_shape (list[int]) – Leading shape of the resulting view. The trailing dimension is inferred from slice_end - slice_start and dtype.
slice_start (int) – Start offset (in bytes from the beginning of the workspace) of the slice to wrap.
slice_end (int) – End offset (in bytes) of the slice. Must lie within a single rank.
dtype (torch.dtype) – Element dtype of the resulting view.
- Returns:
A workspace-backed tensor of shape leading_shape + [-1].
- Return type:
torch.Tensor
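The shape rule above can be illustrated with a standard-library analogue. This is a hypothetical sketch, not FlashInfer's implementation: wrap_slice, EP_SIZE, and SIZE_PER_RANK are illustrative names, a bytearray stands in for the torch.Tensor workspace, and a struct format code ("f" for float32) stands in for dtype. It assumes slice_end is exclusive and that the slice size must divide evenly into rows of prod(leading_shape) elements. It shows the trailing dimension being inferred from slice_end - slice_start and the element size, a check that the slice stays within one rank's row, and that the returned view shares memory with the workspace.

```python
import struct
from math import prod

# Hypothetical geometry: a flat bytearray standing in for an
# [ep_size, size_per_rank] uint8 workspace laid out row by row.
EP_SIZE, SIZE_PER_RANK = 2, 64
workspace = bytearray(EP_SIZE * SIZE_PER_RANK)


def wrap_slice(buf, leading_shape, slice_start, slice_end, fmt, size_per_rank):
    """Hypothetical analogue, not FlashInfer code: return a zero-copy view of
    buf[slice_start:slice_end] typed as `fmt` with shape leading_shape + [-1]."""
    nbytes = slice_end - slice_start  # assumes slice_end is exclusive
    if nbytes <= 0:
        raise ValueError("slice_end must be greater than slice_start")
    # The slice must lie within a single rank's row of the workspace.
    if slice_start // size_per_rank != (slice_end - 1) // size_per_rank:
        raise ValueError("slice crosses a rank boundary")
    row_bytes = prod(leading_shape) * struct.calcsize(fmt)
    if nbytes % row_bytes:  # assumed: the slice must divide evenly
        raise ValueError("slice size is not a multiple of the row size")
    raw = memoryview(buf)[slice_start:slice_end]  # byte slice, no copy
    # Reinterpret the bytes; the last dimension plays the role of "-1".
    return raw.cast(fmt, shape=[*leading_shape, nbytes // row_bytes])


# A float32 payload of shape [2, -1] in bytes [16, 48) of rank 1's row.
view = wrap_slice(workspace, [2], SIZE_PER_RANK + 16, SIZE_PER_RANK + 48, "f", SIZE_PER_RANK)
print(view.shape)  # → (2, 4)

# Writing the underlying workspace bytes is visible through the view.
struct.pack_into("f", workspace, SIZE_PER_RANK + 16 + (1 * 4 + 3) * 4, 1.5)
print(view[1, 3])  # → 1.5
```

In PyTorch terms, a comparable view could be built by slicing a flattened uint8 tensor and reinterpreting it with Tensor.view(dtype) before reshaping to leading_shape + [-1]; whether FlashInfer takes exactly that route is not stated here.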