flashinfer.logits_processor.TopK

class flashinfer.logits_processor.TopK(joint_topk_topp: bool = False, **params: Any)

Top-k filtering processor.

Keeps only the top-k highest probability tokens and masks out the rest.

TensorType.LOGITS -> TensorType.LOGITS | TensorType.PROBS -> TensorType.PROBS

Parameters:
  • joint_topk_topp (bool, optional, Compile-time) – Whether to enable joint top-k and top-p filtering when followed by TopP. Default is False.

  • top_k (int or torch.Tensor, Runtime) – Number of top tokens to keep. Can be a scalar or per-batch tensor.
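When top_k is a per-batch tensor, each batch row is filtered with its own k. A minimal pure-Python sketch of that per-row semantics (illustrative only; `topk_mask_per_row` is a hypothetical helper, not part of the flashinfer API, and the real kernel runs on GPU tensors):

```python
import math

def topk_mask_per_row(logits, top_k):
    """Apply a different k to each row, mirroring a per-batch top_k
    tensor. Non-top-k entries are set to -inf, as the processor does
    for TensorType.LOGITS. Hypothetical helper for illustration."""
    out = []
    for row, k in zip(logits, top_k):
        # k-th largest value in this row is the cutoff
        threshold = sorted(row, reverse=True)[k - 1]
        out.append([v if v >= threshold else -math.inf for v in row])
    return out

# Row 0 keeps only its single largest logit (k=1); row 1 keeps two (k=2).
print(topk_mask_per_row([[0.2, 2.2, 1.0], [0.3, 0.1, -0.5]], top_k=[1, 2]))
# [[-inf, 2.2, -inf], [0.3, 0.1, -inf]]
```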

Examples

>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, TopK, Sample, TensorType
>>> torch.manual_seed(42)
>>>
>>> # Top-k filtering on logits
>>> pipe = LogitsPipe([TopK()], input_type=TensorType.LOGITS)
>>> logits = torch.randn(2, 2, device="cuda")
>>> logits
tensor([[ 0.1940,  2.1614],
        [-0.1721,  0.8491]], device='cuda:0')
>>> topk_logits = pipe(logits, top_k=1)
>>> topk_logits
tensor([[  -inf, 2.1614],
        [  -inf, 0.8491]], device='cuda:0')
>>>
>>> # Top-k filtering on probabilities
>>> pipe = LogitsPipe([TopK()], input_type=TensorType.PROBS)
>>> probs = torch.rand(2, 2, device="cuda")
>>> probs_normed = probs / probs.sum(dim=-1, keepdim=True)
>>> topk_probs = pipe(probs_normed, top_k=1)
>>> topk_probs  # each row is one-hot: the top token keeps probability 1

Notes

When applied to TensorType.LOGITS, sets non-top-k values to -inf. When applied to TensorType.PROBS, zeros out non-top-k values and renormalizes the remaining mass so each row sums to 1.
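The two behaviors above can be sketched in pure Python (illustrative only, not the flashinfer GPU kernel; ties at the cutoff may keep more than k entries here):

```python
import math

def topk_mask_logits(logits, k):
    """TensorType.LOGITS behavior: keep the k largest logits per row,
    set the rest to -inf."""
    out = []
    for row in logits:
        threshold = sorted(row, reverse=True)[k - 1]
        out.append([v if v >= threshold else -math.inf for v in row])
    return out

def topk_renorm_probs(probs, k):
    """TensorType.PROBS behavior: zero out all but the k largest
    probabilities per row, then renormalize so each row sums to 1."""
    out = []
    for row in probs:
        threshold = sorted(row, reverse=True)[k - 1]
        kept = [v if v >= threshold else 0.0 for v in row]
        total = sum(kept)
        out.append([v / total for v in kept])
    return out

print(topk_mask_logits([[0.19, 2.16], [-0.17, 0.85]], k=1))
# [[-inf, 2.16], [-inf, 0.85]]
print(topk_renorm_probs([[0.3, 0.7], [0.9, 0.1]], k=1))
# [[0.0, 1.0], [1.0, 0.0]]
```

A softmax over the masked logits and the renormalized probabilities agree: tokens masked to -inf get exactly zero probability, which is why the pipeline can operate in either representation.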

__init__(joint_topk_topp: bool = False, **params: Any)

Constructor for TopK processor.

Parameters:

joint_topk_topp (bool, optional, Compile-time) – Whether to enable joint top-k and top-p filtering when followed by TopP. Default is False.

Methods

__init__([joint_topk_topp])

Constructor for TopK processor.

legalize(input_type)

Legalize the processor into a list of low-level operators.