flashinfer.logits_processor.TaggedTensor

class flashinfer.logits_processor.TaggedTensor(data: torch.Tensor, type: TensorType)

Tensor wrapper that maintains semantic type information through pipeline execution.

It keeps tensors type-safe as they flow through the logits processing pipeline, while adding no friction for downstream PyTorch code.
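The type-safety idea can be illustrated with a small, hypothetical sketch. It is not FlashInfer's implementation. `Tagged` and `softmax_stage` are illustrative names, the `auto()` enum values are an assumption (only the member names LOGITS, PROBS and INDICES come from this page), and a plain Python list stands in for `torch.Tensor` so the example runs without PyTorch. The stage accepts only data tagged `LOGITS` and returns data tagged `PROBS`, so in this sketch a mis-ordered pipeline fails immediately instead of silently computing on the wrong kind of tensor.

```python
import math
from dataclasses import dataclass
from enum import Enum, auto
from typing import List


class TensorType(Enum):
    # Member names mirror the documented TensorType; values are illustrative.
    LOGITS = auto()
    PROBS = auto()
    INDICES = auto()


@dataclass
class Tagged:
    data: List[float]  # stand-in for torch.Tensor in this sketch
    type: TensorType   # semantic tag carried alongside the data


def softmax_stage(x: Tagged) -> Tagged:
    """Hypothetical pipeline stage: accepts LOGITS, emits PROBS."""
    if x.type is not TensorType.LOGITS:
        raise TypeError(f"softmax_stage expects LOGITS, got {x.type.name}")
    m = max(x.data)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in x.data]
    total = sum(exps)
    return Tagged([e / total for e in exps], TensorType.PROBS)


probs = softmax_stage(Tagged([2.0, 1.0, 0.0], TensorType.LOGITS))
print(probs.type.name)  # PROBS

try:
    softmax_stage(probs)  # wrong tag: PROBS instead of LOGITS
except TypeError as err:
    print(err)  # softmax_stage expects LOGITS, got PROBS
```

The tag travels with the data, so each stage can validate its input without inspecting the numbers themselves.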

Notes

  • TaggedTensor is primarily for internal use by LogitsPipe

  • Users typically work with plain tensors; tagging happens automatically

__init__(data: torch.Tensor, type: TensorType) → None

Methods

__init__(data, type)

indices(t)

Create a TaggedTensor with type TensorType.INDICES.

logits(t)

Create a TaggedTensor with type TensorType.LOGITS.

probs(t)

Create a TaggedTensor with type TensorType.PROBS.

size([dim])

Get the size of the underlying tensor.

Attributes

data

The underlying tensor.

type

The semantic type of the tensor.

device

Get the device of the underlying tensor.

dtype

Get the data type of the underlying tensor.

shape

Get the shape of the underlying tensor.