flashinfer.logits_processor.TopK¶
- class flashinfer.logits_processor.TopK(joint_topk_topp: bool = False, **params: Any)¶
Top-k filtering processor.
Keeps only the top-k highest probability tokens and masks out the rest.
TensorType.LOGITS -> TensorType.LOGITS | TensorType.PROBS -> TensorType.PROBS
- Parameters:
joint_topk_topp (bool, optional, Compile-time) – Whether to enable joint top-k and top-p filtering when followed by TopP. Default is False.
top_k (int or torch.Tensor, Runtime) – Number of top tokens to keep. Can be a scalar or per-batch tensor.
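The scalar-or-per-batch `top_k` semantics above can be sketched in plain PyTorch. This is an illustrative CPU reference of the masking behavior, not the flashinfer CUDA kernel; the function name `topk_mask_logits` is hypothetical:

```python
import torch

def topk_mask_logits(logits: torch.Tensor, top_k) -> torch.Tensor:
    """Keep the k largest logits per row and set the rest to -inf.

    `top_k` may be an int shared across the batch, or a 1-D tensor with
    one k per batch row, mirroring the scalar-or-per-batch runtime
    argument described in the parameter list.
    """
    batch, _vocab = logits.shape
    if not torch.is_tensor(top_k):
        top_k = torch.full((batch,), top_k, dtype=torch.long)
    out = torch.full_like(logits, float("-inf"))
    for i in range(batch):
        k = int(top_k[i])
        idx = torch.topk(logits[i], k).indices  # indices of the k largest entries
        out[i, idx] = logits[i, idx]
    return out

logits = torch.tensor([[0.5, 2.0, -1.0], [1.0, 0.0, 3.0]])
print(topk_mask_logits(logits, 1))                     # shared k=1
print(topk_mask_logits(logits, torch.tensor([2, 1])))  # per-row k
```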
Examples
>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, TopK, Sample, TensorType
>>> torch.manual_seed(42)
>>>
>>> # Top-k filtering on logits
>>> pipe = LogitsPipe([TopK()], input_type=TensorType.LOGITS)
>>> logits = torch.randn(2, 2, device="cuda")
>>> logits
tensor([[ 0.1940,  2.1614],
        [-0.1721,  0.8491]], device='cuda:0')
>>> topk_logits = pipe(logits, top_k=1)
>>> topk_logits
tensor([[  -inf, 2.1614],
        [  -inf, 0.8491]], device='cuda:0')
>>>
>>> # Top-k filtering on probabilities
>>> pipe = LogitsPipe([TopK()], input_type=TensorType.PROBS)
>>> probs = torch.randn(2, 2, device="cuda")
>>> probs_normed = probs / probs.sum(dim=-1, keepdim=True)
>>> probs_normed
tensor([[  4.4998,  -3.4998],
        [-18.2893,  19.2893]], device='cuda:0')
>>> topk_probs = pipe(probs_normed, top_k=1)
>>> topk_probs
tensor([[1., 0.],
        [0., 1.]], device='cuda:0')
Notes
When applied to TensorType.LOGITS, sets non-top-k values to -inf. When applied to TensorType.PROBS, zeros out non-top-k values and renormalizes.
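The PROBS path described in the Notes can be illustrated in plain PyTorch: zero out everything outside the top-k per row, then renormalize so each row sums to 1. This is a CPU sketch of the documented behavior, not the flashinfer kernel; `topk_renorm_probs` is a hypothetical name:

```python
import torch

def topk_renorm_probs(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Zero out non-top-k probabilities per row, then renormalize."""
    vals, idx = torch.topk(probs, k, dim=-1)
    # Scatter the kept values into an all-zeros tensor of the same shape.
    masked = torch.zeros_like(probs).scatter_(-1, idx, vals)
    return masked / masked.sum(dim=-1, keepdim=True)

probs = torch.tensor([[0.1, 0.6, 0.3], [0.5, 0.3, 0.2]])
print(topk_renorm_probs(probs, 2))  # each row sums to 1 again
```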
- __init__(joint_topk_topp: bool = False, **params: Any)¶
Constructor for TopK processor.
- Parameters:
joint_topk_topp (bool, optional, Compile-time) – Whether to enable joint top-k and top-p filtering when followed by TopP. Default is False.
Methods
__init__([joint_topk_topp]) – Constructor for TopK processor.
legalize(input_type) – Legalize the processor into a list of low-level operators.