flashinfer.logits_processor¶
A declarative, pluggable framework for building processing pipelines for LLM outputs.
Pipeline Construction¶
Use LogitsPipe to create processing pipelines:
import torch
from flashinfer.logits_processor import LogitsPipe, Temperature, Softmax, TopP, Sample

# Create a pipeline
pipe = LogitsPipe([
    Temperature(),  # Scale logits by temperature
    Softmax(),      # Convert logits to probabilities
    TopP(),         # Apply top-p filtering
    Sample(),       # Sample from the distribution
])

# Apply the pipeline
batch_size = 4
vocab_size = 5
logits = torch.randn(batch_size, vocab_size, device="cuda")
output_ids = pipe(logits, temperature=0.7, top_p=0.9)
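Runtime parameters such as temperature and top_p are supplied when the pipeline is called, not when it is constructed. As a minimal sketch, assuming LogitsPipe accepts an input_type argument and TensorType provides a PROBS member (both are assumptions here, not confirmed above), a pipeline can also start from probabilities that have already been computed:

import torch
from flashinfer.logits_processor import LogitsPipe, TopP, Sample, TensorType

# Assumed: input_type tells the pipe its first tensor is already probabilities,
# so Temperature/Softmax are omitted.
prob_pipe = LogitsPipe(
    [TopP(), Sample()],
    input_type=TensorType.PROBS,  # assumed member name for probability tensors
)

logits = torch.randn(4, 5, device="cuda")
probs = torch.softmax(logits / 0.7, dim=-1)
output_ids = prob_pipe(probs, top_p=0.9)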
Pipeline¶
LogitsPipe: Provides a declarative way to build processing pipelines for LLM outputs.
Processors¶
LogitsProcessor: Defines high-level transformations that can be applied to logits or probabilities.
Temperature: Temperature scaling processor for logits.
Softmax: Softmax processor to convert logits to probabilities.
TopK: Top-k filtering processor.
TopP: Top-p (nucleus) filtering processor.
MinP: Min-p filtering processor.
Sample: Sampling processor to generate token indices.
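These processors can be mixed and matched within a single pipe. The sketch below assumes TopK and MinP are exported alongside the classes imported earlier, and that their thresholds are passed at call time as top_k and min_p (parameter names assumed):

import torch
from flashinfer.logits_processor import LogitsPipe, Temperature, Softmax, TopK, MinP, Sample

# Sketch of a pipeline combining several filtering processors.
pipe = LogitsPipe([
    Temperature(),  # scale logits
    Softmax(),      # logits -> probabilities
    TopK(),         # keep only the k most likely tokens
    MinP(),         # drop tokens far below the most likely one
    Sample(),       # draw token ids from the filtered distribution
])

logits = torch.randn(4, 128, device="cuda")
output_ids = pipe(logits, temperature=0.8, top_k=40, min_p=0.05)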
Types¶
TensorType: Represents the semantic type of tensors in the pipeline.
TaggedTensor: Tensor wrapper that maintains semantic type information through pipeline execution.
Customization Features¶
Custom Logits Processor¶
You can create your own logits processor by subclassing LogitsProcessor:
from typing import Any, List
from flashinfer.logits_processor import LogitsPipe, LogitsProcessor, Op, TaggedTensor, TensorType

class CustomLogitsProcessor(LogitsProcessor):
    def __init__(self, **params: Any):
        super().__init__(**params)

    def legalize(self, input_type: TensorType) -> List["Op"]:
        # Lower the high-level processor into one or more concrete operators
        return [CustomOp(**self.params)]

class CustomOp(Op):
    # Define the input and output tensor types
    IN = TensorType.LOGITS
    OUT = TensorType.LOGITS

    def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor:
        pass

pipe = LogitsPipe([CustomLogitsProcessor()])  # The pipe will be compiled into [CustomOp]
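Because CustomOp declares LOGITS as both its input and output type, the custom processor behaves like any other logits-to-logits transformation and composes with the built-in processors. A brief usage sketch, assuming CustomOp.__call__ above is given a real implementation:

import torch
from flashinfer.logits_processor import LogitsPipe, Temperature, Softmax, Sample

# Compose the custom processor with built-in ones; it must run while the
# tensor is still logits, i.e. before Softmax converts it to probabilities.
pipe = LogitsPipe([
    Temperature(),
    CustomLogitsProcessor(),  # defined above
    Softmax(),
    Sample(),
])

logits = torch.randn(4, 32000, device="cuda")
output_ids = pipe(logits, temperature=0.7)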
Custom Fusion Rules¶
You can register custom fusion rules to optimize specific processor combinations:
from typing import List
from flashinfer.logits_processor import FusionRule, LogitsPipe, Op, Sample, Softmax, Temperature

def custom_fusion_guard(window: List[Op]) -> bool:
    # Whether the fusion should be applied
    return True

def build_custom_fusion(window: List[Op]) -> Op:
    # Create a fused operator by setting the parameters etc.
    return CustomOp()

custom_rule = FusionRule(
    pattern=(Temperature, Softmax),
    guard=custom_fusion_guard,
    build=build_custom_fusion,
    prio=20,
)

pipe = LogitsPipe(
    [Temperature(), Softmax(), Sample()],
    custom_fusion_rules=[custom_rule],
)  # The compiled ops in the pipeline will be [CustomOp, Sample]
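For reference, the computation such a fused Temperature + Softmax operator performs is a single softmax over temperature-scaled logits. Below is a standalone sketch of that math in plain PyTorch; the fused CustomOp built above would wrap an equivalent (typically kernel-fused) computation:

import torch

def fused_temperature_softmax(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Applying Temperature() followed by Softmax() in one step:
    # probs = softmax(logits / temperature)
    return torch.softmax(logits / temperature, dim=-1)

logits = torch.randn(4, 32000, device="cuda")
probs = fused_temperature_softmax(logits, temperature=0.7)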