flashinfer.logits_processor.LogitsProcessor

class flashinfer.logits_processor.LogitsProcessor(**params: Any)

LogitsProcessor defines high-level transformations that can be applied to logits or probabilities. Each processor is automatically legalized into one or more low-level Op or ParameterizedOp instances, which can be type-checked, validated, and fused for better performance. Users can subclass it to implement their own processors.
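The type-checking step can be pictured as a walk over the legalized op chain that confirms each op's declared input type matches the previous op's output type. The sketch below illustrates that idea under this assumption only. ToyOp, type_check, and this TensorType enum are hypothetical stand-ins, not FlashInfer classes, and the op names simply echo the ones in the example repr below.

```python
from dataclasses import dataclass
from enum import Enum


class TensorType(Enum):
    # Hypothetical stand-in for the tensor kinds a pipeline stage consumes or produces.
    LOGITS = "logits"
    PROBS = "probs"
    INDICES = "indices"


@dataclass(frozen=True)
class ToyOp:
    # Hypothetical low-level op that declares exactly one input type and one output type.
    name: str
    in_type: TensorType
    out_type: TensorType


def type_check(ops, input_type):
    """Walk the op chain and return the final output type, raising on a mismatch."""
    current = input_type
    for op in ops:
        if op.in_type is not current:
            raise TypeError(f"{op.name} expects {op.in_type.value}, got {current.value}")
        current = op.out_type
    return current


# Illustrative ops; names are borrowed from the example repr, semantics are assumed.
softmax = ToyOp("SoftmaxOp", TensorType.LOGITS, TensorType.PROBS)
probs_topk = ToyOp("ProbsTopKOp", TensorType.PROBS, TensorType.PROBS)
probs_sample = ToyOp("ProbsSampleOp", TensorType.PROBS, TensorType.INDICES)

print(type_check([probs_topk, probs_sample], TensorType.PROBS))  # TensorType.INDICES
```

In this sketch, feeding logits directly into ProbsSampleOp raises a TypeError, which is the kind of mismatch that type checking is meant to catch before anything executes.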

Parameters:

**params (Any) – Processor-specific compile-time parameters.

Examples

>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, TopK, Sample, TensorType
>>> _ = torch.manual_seed(42)  # discard the returned Generator so the doctest prints nothing
>>>
>>> # Create a pipeline that legalizes to a fused op.
>>> pipe = LogitsPipe([
...     TopK(),         # Top-k filtering on probabilities
...     Sample()        # Sample from the filtered distribution
... ], input_type=TensorType.PROBS)  # the input is treated as probabilities
>>>
>>> pipe
LogitsPipe([TopK -> Sample], ops=[ProbsTopKOp -> ProbsSampleOp], compiled_ops=[FusedProbsTopKSampleOp])
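In the repr above, the two legalized ops ProbsTopKOp -> ProbsSampleOp appear as a single FusedProbsTopKSampleOp under compiled_ops. One simple way to model such a fusion pass is a greedy left-to-right scan that replaces known adjacent op pairs with a fused op. The FUSION_RULES table and fuse() function are hypothetical and only illustrate the idea; FlashInfer's actual compiler may work differently.

```python
# Hypothetical fusion table: an adjacent pair of op names maps to one fused op name.
FUSION_RULES = {
    ("ProbsTopKOp", "ProbsSampleOp"): "FusedProbsTopKSampleOp",
}


def fuse(op_names):
    """Greedily replace adjacent op pairs listed in FUSION_RULES with their fused op."""
    fused = []
    i = 0
    while i < len(op_names):
        pair = tuple(op_names[i:i + 2])
        if len(pair) == 2 and pair in FUSION_RULES:
            fused.append(FUSION_RULES[pair])  # consume both ops
            i += 2
        else:
            fused.append(op_names[i])  # no rule applies; keep the op as-is
            i += 1
    return fused


print(fuse(["ProbsTopKOp", "ProbsSampleOp"]))  # ['FusedProbsTopKSampleOp']
```

Ops without a matching rule pass through unchanged in this sketch, so a pipeline still runs (just unfused) when no fused kernel exists for a given combination.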

Notes

Subclasses must implement the legalize() method, which converts the high-level processor into one or more low-level operators, each with specific input and output types.
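The subclassing contract can be sketched as follows, assuming a processor stores its compile-time **params in __init__ and maps each supported input type to a list of low-level ops in legalize(). ToyProcessor, ToyTopK, and the returned op names (LogitsTopKOp in particular) are hypothetical; a real subclass would return FlashInfer Op or ParameterizedOp instances rather than strings.

```python
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, List


class TensorType(Enum):
    # Hypothetical stand-in for the input kinds a processor may receive.
    LOGITS = "logits"
    PROBS = "probs"


class ToyProcessor(ABC):
    """Hypothetical base class mirroring the documented pattern."""

    def __init__(self, **params: Any):
        # Compile-time parameters are captured once, when the processor is built.
        self.params = params

    @abstractmethod
    def legalize(self, input_type: TensorType) -> List[str]:
        """Return the names of the low-level ops this processor lowers to."""


class ToyTopK(ToyProcessor):
    def legalize(self, input_type: TensorType) -> List[str]:
        # Pick a different (hypothetical) low-level op for each supported input type.
        if input_type is TensorType.LOGITS:
            return ["LogitsTopKOp"]
        if input_type is TensorType.PROBS:
            return ["ProbsTopKOp"]
        raise ValueError(f"unsupported input type: {input_type}")


print(ToyTopK(k=5).legalize(TensorType.PROBS))  # ['ProbsTopKOp']
```

Raising on an unsupported input type lets this sketch reject an invalid pipeline configuration when it is built, before any kernel would run.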

__init__(**params: Any)

Initialize the processor.

Parameters:

**params (Any) – Processor-specific compile-time parameters.

Methods

__init__(**params)

Initialize the processor.

legalize(input_type)

Legalize the processor into a list of low-level operators.