flashinfer.comm.moe_a2a_wrap_payload_tensor_in_workspace

flashinfer.comm.moe_a2a_wrap_payload_tensor_in_workspace(workspace: Tensor, leading_shape: list[int], slice_start: int, slice_end: int, dtype: dtype) → Tensor

Wrap a slice of the shared workspace as a typed tensor view.

Parameters:
  • workspace (torch.Tensor) – [ep_size, size_per_rank] (or [size_per_rank]) workspace tensor.

  • leading_shape (list[int]) – Leading shape of the resulting view. The trailing dimension is inferred: it is the number of dtype elements that fit in slice_end - slice_start bytes, divided by the product of leading_shape.

  • slice_start (int) – Start offset (in bytes from the beginning of the workspace) of the slice to wrap.

  • slice_end (int) – End offset (in bytes) of the slice. The slice must lie within a single rank's portion of the workspace.

  • dtype (torch.dtype) – Element dtype of the resulting view.
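
The shape arithmetic described by the parameters above can be sketched in plain Python. This is only an illustration under stated assumptions: infer_trailing_dim is a hypothetical helper, not part of flashinfer, and the divisibility checks it performs are assumptions about what a valid call needs rather than documented behaviour of the library.

```python
import math


def infer_trailing_dim(leading_shape, slice_start, slice_end, itemsize):
    """Hypothetical helper: size of the inferred trailing dimension.

    itemsize is the size in bytes of one element of the target dtype
    (for example 2 for a 16-bit float).
    """
    nbytes = slice_end - slice_start
    if nbytes < 0:
        raise ValueError("slice_end must not be smaller than slice_start")
    if nbytes % itemsize != 0:
        raise ValueError("slice length is not a multiple of the element size")
    numel = nbytes // itemsize
    leading = math.prod(leading_shape)
    if leading == 0 or numel % leading != 0:
        raise ValueError("element count is not divisible by the leading shape")
    return numel // leading


# A 4096-byte slice viewed as 2-byte elements with leading shape [2, 8]:
# 4096 / 2 = 2048 elements, 2048 / (2 * 8) = 128 per trailing row.
trailing = infer_trailing_dim([2, 8], slice_start=1024, slice_end=5120, itemsize=2)
print(trailing)  # 128
```

Under these assumptions the resulting view would have shape [2, 8, 128].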

Returns:

A workspace-backed tensor of shape leading_shape + [-1], where -1 stands for the inferred trailing dimension.

Return type:

torch.Tensor
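
What "workspace-backed" means can be illustrated with a standard-library analogue. The sketch below does not call flashinfer or torch; it uses a bytearray as a stand-in for one rank's byte workspace and memoryview.cast as a stand-in for the typed view, so all names and sizes are illustrative assumptions. The point it shows is that the view aliases the underlying bytes instead of copying them.

```python
import struct

# Stand-in for a single rank's workspace: 64 raw bytes, all zero.
workspace = bytearray(64)

slice_start, slice_end = 16, 48          # a 32-byte slice
leading_shape = [2]
itemsize = struct.calcsize("f")          # 4 bytes for a native float
trailing = (slice_end - slice_start) // itemsize // 2   # 8 floats / 2 rows = 4

# Typed, shaped view over the byte slice; no data is copied.
view = memoryview(workspace)[slice_start:slice_end].cast(
    "f", shape=leading_shape + [trailing]
)

# Write a float directly into the workspace bytes at the position
# that corresponds to element [1, 3] of the view.
offset = slice_start + (1 * trailing + 3) * itemsize
struct.pack_into("f", workspace, offset, 1.5)

print(view.shape)   # (2, 4)
print(view[1, 3])   # 1.5
```

Because the view shares storage with the workspace, data written into the slice is visible through the view without an extra copy.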