flashinfer.logits_processor.TopP

class flashinfer.logits_processor.TopP(**params: Any)

Top-p (nucleus) filtering processor.

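Keeps the smallest set of highest-probability tokens whose cumulative probability reaches the threshold p. All other tokens are set to zero, and the kept probabilities are renormalized so that each row sums to 1.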
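The filtering rule can be illustrated with a plain-Python reference sketch for a single row of probabilities. The function name top_p_renorm_row is hypothetical, and this is not how FlashInfer implements the operator on the GPU. Tie-breaking and floating-point rounding near the threshold may differ from the library.

```python
from typing import List


def top_p_renorm_row(probs: List[float], top_p: float) -> List[float]:
    """Hypothetical reference for top-p (nucleus) filtering of one row.

    Keeps the smallest set of highest-probability tokens whose cumulative
    probability reaches top_p, zeroes the rest, and renormalizes the kept
    probabilities so they sum to 1.
    """
    if not 0.0 < top_p <= 1.0:
        raise ValueError("top_p must be in (0, 1]")
    # Visit token indices from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = set()
    cumulative = 0.0
    for i in order:
        kept.add(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break  # the nucleus is complete
    # Renormalize over the kept tokens only.
    total = sum(probs[i] for i in kept)
    return [probs[i] / total if i in kept else 0.0 for i in range(len(probs))]


# 0.95 alone reaches 0.9, so only that token survives.
print(top_p_renorm_row([0.05, 0.95], top_p=0.9))  # [0.0, 1.0]
# 0.6 alone falls short of 0.9, so both tokens are kept.
print(top_p_renorm_row([0.4, 0.6], top_p=0.9))    # [0.4, 0.6]
```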

TensorType.PROBS -> TensorType.PROBS

Parameters:

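top_p (float or torch.Tensor, Runtime) – Cumulative probability threshold in (0, 1], supplied when the pipe is called. Either a scalar applied to every row, or a tensor of shape (batch_size,) giving one threshold per row.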
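The scalar-versus-per-row behaviour of top_p can be sketched without a GPU. The helpers per_row_top_p and nucleus_sizes below are hypothetical. They use Python lists in place of the (batch_size,) tensor that FlashInfer expects, and they only count how many tokens each row keeps under its own threshold.

```python
from typing import List, Sequence, Union

TopPArg = Union[float, Sequence[float]]


def per_row_top_p(top_p: TopPArg, batch_size: int) -> List[float]:
    """Hypothetical: broadcast a scalar threshold, or validate a per-row list."""
    if isinstance(top_p, (int, float)):
        values = [float(top_p)] * batch_size
    else:
        values = [float(p) for p in top_p]
        if len(values) != batch_size:
            raise ValueError(f"expected {batch_size} thresholds, got {len(values)}")
    for p in values:
        if not 0.0 < p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {p}")
    return values


def nucleus_sizes(rows: List[List[float]], top_p: TopPArg) -> List[int]:
    """Hypothetical: number of tokens each row keeps under its threshold."""
    sizes = []
    for row, p in zip(rows, per_row_top_p(top_p, len(rows))):
        cumulative = 0.0
        count = 0
        # Accumulate probabilities from largest to smallest until p is reached.
        for prob in sorted(row, reverse=True):
            cumulative += prob
            count += 1
            if cumulative >= p:
                break
        sizes.append(count)
    return sizes


rows = [[0.1, 0.2, 0.7], [0.1, 0.2, 0.7]]
print(nucleus_sizes(rows, 0.5))          # [1, 1]: one scalar shared by both rows
print(nucleus_sizes(rows, [0.5, 0.95]))  # [1, 3]: each row uses its own threshold
```

With identical rows, a per-row threshold lets one request sample greedily while another keeps almost the whole distribution.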

Examples

>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, TopP
>>> pipe = LogitsPipe([TopP()])
>>> probs = torch.tensor([[0.05, 0.95], [0.40, 0.60]], device="cuda")  # each row sums to 1
>>> topp_probs = pipe(probs, top_p=0.9)
>>> topp_probs
tensor([[0.0000, 1.0000],
        [0.4000, 0.6000]], device='cuda:0')

__init__(**params: Any)

Constructor for the TopP processor. TopP takes no compile-time parameters; top_p is supplied at runtime when the pipe is called.

Methods

__init__(**params)      Constructor for TopP processor.
legalize(input_type)    Legalize the processor into a list of low-level operators.
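The TensorType.PROBS -> TensorType.PROBS contract and the role of legalize can be modelled in a few lines. Everything below (TensorType as redefined here, TopPSketch, TopPOpSketch) is a hypothetical stand-in written for illustration, not FlashInfer's source. The sketch assumes legalization rejects any input type other than probabilities and otherwise lowers the processor to a single operator that carries its parameters.

```python
from enum import Enum, auto
from typing import Any, Dict, List


class TensorType(Enum):
    # Hypothetical stand-in for the library's tensor-type tags.
    LOGITS = auto()
    PROBS = auto()


class TopPOpSketch:
    """Hypothetical low-level operator that carries the processor's params."""

    def __init__(self, **params: Any) -> None:
        self.params: Dict[str, Any] = params


class TopPSketch:
    """Hypothetical processor whose contract is PROBS -> PROBS."""

    def __init__(self, **params: Any) -> None:
        # No compile-time parameters are required; top_p arrives at runtime.
        self.params = params

    def legalize(self, input_type: TensorType) -> List[TopPOpSketch]:
        # Top-p is defined on probabilities, so reject any other input type.
        if input_type != TensorType.PROBS:
            raise ValueError(f"TopP expects PROBS input, got {input_type.name}")
        return [TopPOpSketch(**self.params)]


ops = TopPSketch().legalize(TensorType.PROBS)
print(len(ops))  # 1
```

Under this model, a pipeline that starts from logits would need a probability-producing step (for example Softmax) ahead of TopP.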