flashinfer.logits_processor.LogitsPipe

class flashinfer.logits_processor.LogitsPipe(processors: List[LogitsProcessor], compile: bool = True, input_type: TensorType | None = None, custom_fusion_rules: List[FusionRule] | None = None, custom_validity_checks: List[Callable[[List[Op]], None]] | None = None)

Provides a declarative way to build processing pipelines for LLM outputs.

Parameters:
  • processors (List[LogitsProcessor]) – List of processors to apply in sequence.

  • compile (bool, optional) – Whether to compile the pipeline with fusion optimizations. Default is True. If False, LogitsPipe.compile() can be called later to compile the pipeline after instantiation.

  • input_type (Optional[TensorType], optional) – Expected input tensor type, either TensorType.LOGITS or TensorType.PROBS. Required if the first processor accepts both types; otherwise it is inferred automatically from the first processor. Default is None.

  • custom_fusion_rules (Optional[List[FusionRule]], optional) – Additional fusion rules to apply during compilation. Default is None.

  • custom_validity_checks (Optional[List[ValidityCheck]], optional) – Additional validity checks to apply during compilation. Default is None.
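
The input_type rule above can be illustrated with a small framework-free sketch. The names below (ACCEPTS, infer_initial_type) are hypothetical helpers, not flashinfer API; the accepted-type table mirrors what the examples imply (Temperature and Softmax consume logits, Sample consumes probabilities, TopK accepts either, which is why it alone cannot determine the pipeline's initial type):

```python
from enum import Enum

class TensorType(Enum):
    LOGITS = "logits"
    PROBS = "probs"

# Hypothetical table of accepted input types per processor,
# mirroring the behavior described in the examples below.
ACCEPTS = {
    "Temperature": {TensorType.LOGITS},
    "Softmax": {TensorType.LOGITS},
    "TopK": {TensorType.LOGITS, TensorType.PROBS},
    "Sample": {TensorType.PROBS},
}

def infer_initial_type(processors, input_type=None):
    # The pipeline's initial type comes from its first processor...
    accepted = ACCEPTS[processors[0]]
    if len(accepted) == 1:
        return next(iter(accepted))
    # ...unless that processor accepts both LOGITS and PROBS,
    # in which case the caller must disambiguate via input_type.
    if input_type is None:
        raise ValueError(
            f"{processors[0]} accepts both LOGITS and PROBS; "
            "pass input_type to disambiguate"
        )
    return input_type
```

For example, infer_initial_type(["TopK", "Sample"]) raises, while passing input_type=TensorType.PROBS resolves it, matching the prob_pipe example below.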

Examples

>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, Temperature, Softmax, TopK, Sample
>>> torch.manual_seed(42)
>>>
>>> # Basic pipeline with temperature, top-k, and sampling
>>> pipe = LogitsPipe([
...     Temperature(),
...     Softmax(),
...     TopK(),
...     Sample(deterministic=True)
... ])
>>>
>>> # Execute the pipeline
>>> logits = torch.randn(4, 32000, device="cuda")  # [batch_size, vocab_size]
>>> token_ids = pipe(logits, temperature=0.9, top_k=40)
>>> token_ids
tensor([15806,  8154, 13923, 20311], device='cuda:0')
>>>
>>> # Pipeline starting from probabilities
>>> from flashinfer.logits_processor import TensorType, TopK
>>>
>>> prob_pipe = LogitsPipe(
...     [TopK(), Sample()],
...     input_type=TensorType.PROBS
... )
>>> probs = torch.softmax(logits, dim=-1)
>>> token_ids = prob_pipe(probs, top_k=40)
>>> token_ids
tensor([  346, 14846,  1517,  9006], device='cuda:0')
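
For intuition, here is a framework-free sketch of what the unfused Temperature -> Softmax -> TopK -> Sample chain computes for a single row of logits. The function name is hypothetical and the loop-based code is for illustration only; the real pipeline operates on batched CUDA tensors with fused kernels:

```python
import math
import random

def pipeline_reference(logits, temperature=0.9, top_k=40, seed=0):
    # Temperature: scale logits before normalization.
    scaled = [x / temperature for x in logits]
    # Softmax: numerically stable normalization to probabilities.
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # TopK: keep the k largest probabilities, zero the rest, renormalize.
    kth = sorted(probs, reverse=True)[top_k - 1]
    masked = [p if p >= kth else 0.0 for p in probs]
    norm = sum(masked)
    masked = [p / norm for p in masked]
    # Sample: draw one token id from the restricted distribution.
    rng = random.Random(seed)
    return rng.choices(range(len(masked)), weights=masked, k=1)[0]
```

Every sampled token id necessarily falls inside the top-k set, since all other entries carry zero probability after masking.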

Notes

  • The pipeline automatically validates type compatibility between operations.

  • Operations are fused when possible.

  • Runtime parameters (like temperature and top_k) are passed when the pipeline is called, e.g. pipe(logits, temperature=0.9, top_k=40).

  • The output is always a plain torch.Tensor, not a TaggedTensor.

__init__(processors: List[LogitsProcessor], compile: bool = True, input_type: TensorType | None = None, custom_fusion_rules: List[FusionRule] | None = None, custom_validity_checks: List[Callable[[List[Op]], None]] | None = None)

Constructor for a LogitsPipe.

Methods

__init__(processors[, compile, input_type, ...])

Constructor for a LogitsPipe.

compile([custom_fusion_rules, ...])

Compile the pipeline.

Attributes

initial_type

The initial input tensor type of the pipeline.