flashinfer.logits_processor.TaggedTensor
- class flashinfer.logits_processor.TaggedTensor(data: torch.Tensor, type: TensorType)
Tensor wrapper that maintains semantic type information through pipeline execution.
It ensures type safety as tensors flow through the logits processing pipeline while adding no friction for downstream PyTorch code.
Notes
TaggedTensor is primarily for internal use by LogitsPipe. Users typically work with plain tensors; tagging happens automatically.
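The idea of carrying a semantic tag alongside the data can be sketched without FlashInfer or PyTorch. The following is a hypothetical stand-in, not the library's implementation: `Kind`, `Tagged`, `softmax_stage`, and `run` are illustrative names, and plain Python lists take the place of tensors. It only shows how a pipeline stage might check the tag on entry, and how a plain input could be tagged and unwrapped at the pipeline boundary.

```python
import math
from dataclasses import dataclass
from enum import Enum, auto
from typing import List


class Kind(Enum):
    """Hypothetical stand-in for TensorType."""
    LOGITS = auto()
    PROBS = auto()
    INDICES = auto()


@dataclass(frozen=True)
class Tagged:
    """Hypothetical stand-in for TaggedTensor: data plus a semantic tag."""
    data: List[float]
    type: Kind


def softmax_stage(x: Tagged) -> Tagged:
    """Accept only LOGITS and return PROBS; reject any other tag."""
    if x.type is not Kind.LOGITS:
        raise TypeError(f"softmax_stage expects LOGITS, got {x.type.name}")
    peak = max(x.data)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in x.data]
    total = sum(exps)
    return Tagged([e / total for e in exps], Kind.PROBS)


def run(raw: List[float]) -> List[float]:
    """Tag a plain input on entry and unwrap the result on exit."""
    tagged = Tagged(raw, Kind.LOGITS)
    return softmax_stage(tagged).data
```

In this sketch the caller of `run` never sees the wrapper, which mirrors the note that tagging happens automatically; passing an already-normalized `Tagged(..., Kind.PROBS)` to `softmax_stage` raises a `TypeError` instead of silently applying softmax twice.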
- __init__(data: torch.Tensor, type: TensorType) -> None
Methods
- __init__(data, type)
- indices(t): Create a TaggedTensor with type TensorType.INDICES.
- logits(t): Create a TaggedTensor with type TensorType.LOGITS.
- probs(t): Create a TaggedTensor with type TensorType.PROBS.
- size([dim]): Get the size of the underlying tensor.
Attributes
- data: The underlying tensor.
- type: The semantic type of the tensor.
- device: Get the device of the underlying tensor.
- dtype: Get the data type of the underlying tensor.
- shape: Get the shape of the underlying tensor.
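The factory methods and delegated attributes listed above follow a common wrapper pattern, which can be sketched as below. This is a minimal illustration under stated assumptions, not FlashInfer's code: `Kind`, `FakeTensor`, and `TaggedSketch` are hypothetical names, `FakeTensor` records only a shape so the example runs without PyTorch, and the `size` behaviour is assumed to mirror `torch.Tensor.size` (full shape with no argument, a single extent with `dim`).

```python
from enum import Enum, auto
from typing import Optional, Tuple, Union


class Kind(Enum):
    """Hypothetical stand-in for TensorType."""
    LOGITS = auto()
    PROBS = auto()
    INDICES = auto()


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only records a shape."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape


class TaggedSketch:
    """Hypothetical wrapper: one factory per tag, attributes forwarded to data."""

    def __init__(self, data: FakeTensor, type: Kind) -> None:
        self.data = data
        self.type = type

    @classmethod
    def logits(cls, t: FakeTensor) -> "TaggedSketch":
        return cls(t, Kind.LOGITS)

    @classmethod
    def probs(cls, t: FakeTensor) -> "TaggedSketch":
        return cls(t, Kind.PROBS)

    @classmethod
    def indices(cls, t: FakeTensor) -> "TaggedSketch":
        return cls(t, Kind.INDICES)

    @property
    def shape(self) -> Tuple[int, ...]:
        # Forward to the wrapped object rather than storing a copy.
        return self.data.shape

    def size(self, dim: Optional[int] = None) -> Union[Tuple[int, ...], int]:
        # No argument: the full shape; with dim: the extent along that axis.
        return self.shape if dim is None else self.shape[dim]


batch = TaggedSketch.probs(FakeTensor((2, 5)))
print(batch.type.name, batch.size(), batch.size(-1))
```

Forwarding `shape` and `size` to the wrapped object keeps the wrapper thin: code that only inspects shapes can treat the wrapper much like the underlying tensor, while the tag stays available for stages that need it.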