flashinfer.logits_processor.TopP
- class flashinfer.logits_processor.TopP(**params: Any)
Top-p (nucleus) filtering processor.
Keeps the highest-probability tokens until their cumulative probability reaches the threshold p; the remaining tokens are filtered out.
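The filtering rule can be sketched in plain Python for a single row of probabilities. This is a minimal sketch of the semantics, not FlashInfer's GPU kernel. The helper name `top_p_filter`, the list-based input, the "stop as soon as the cumulative mass is at least p" boundary rule, and the final renormalization (suggested by this entry's example, where the single surviving token becomes 1.0) are all assumptions.

```python
def top_p_filter(probs, p):
    """Hypothetical reference sketch of top-p (nucleus) filtering for one row.

    Assumes `probs` is a list of non-negative floats with a positive sum
    and 0 < p <= 1. Keeps the smallest set of highest-probability tokens
    whose cumulative mass reaches p, zeroes the rest, and renormalizes.
    """
    if not 0.0 < p <= 1.0:
        raise ValueError("p must be in (0, 1]")
    # Token indices ordered from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = set()
    cumulative = 0.0
    for i in order:
        kept.add(i)
        cumulative += probs[i]
        if cumulative >= p:  # nucleus is large enough; stop adding tokens
            break
    # Zero out every token outside the nucleus.
    filtered = [x if i in kept else 0.0 for i, x in enumerate(probs)]
    # Assumed renormalization so the surviving tokens sum to 1.
    total = sum(filtered)
    return [x / total for x in filtered]


print(top_p_filter([0.0824, 0.9176], 0.9))  # [0.0, 1.0]
```

Sorting keeps the rule easy to read; high-performance GPU implementations may avoid a full sort, so treat this only as a description of the intended result.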
TensorType.PROBS -> TensorType.PROBS
- Parameters:
top_p (float or torch.Tensor, Runtime) – Cumulative probability threshold in (0, 1]. Can be a scalar or per-batch tensor.
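Because `top_p` may be either a scalar or a per-batch tensor, each row of the batch effectively ends up with its own threshold. The hypothetical helper `expand_top_p` sketches that broadcasting together with the (0, 1] range check, using Python lists in place of `torch.Tensor`. It is an illustration only and is not part of the `flashinfer` API.

```python
from typing import List, Sequence, Union


def expand_top_p(top_p: Union[float, Sequence[float]], batch_size: int) -> List[float]:
    """Hypothetical helper: return one top-p threshold per batch row."""
    if isinstance(top_p, (int, float)):
        # Scalar: the same threshold applies to every row.
        thresholds = [float(top_p)] * batch_size
    else:
        # Per-batch values: exactly one threshold per row is required.
        thresholds = [float(p) for p in top_p]
        if len(thresholds) != batch_size:
            raise ValueError(f"expected {batch_size} thresholds, got {len(thresholds)}")
    # Every threshold must lie in the half-open interval (0, 1].
    for p in thresholds:
        if not 0.0 < p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {p}")
    return thresholds


print(expand_top_p(0.9, 2))         # [0.9, 0.9]
print(expand_top_p([0.5, 1.0], 2))  # [0.5, 1.0]
```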
Examples
>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, Softmax, TopP, Sample
>>> torch.manual_seed(42)
>>> pipe = LogitsPipe([TopP()])
>>> probs = torch.randn(2, 2, device="cuda")
>>> probs_normed = probs / probs.sum(dim=-1, keepdim=True)
>>> probs_normed
tensor([[ 0.0824,  0.9176],
        [-0.2541,  1.2541]], device='cuda:0')
>>> topp_probs = pipe(probs_normed, top_p=0.9)
>>> topp_probs
tensor([[0., 1.],
        [0., 1.]], device='cuda:0')
See also
- __init__(**params: Any)
Constructor for TopP processor. No compile-time parameters are needed.
Methods
- __init__(**params) – Constructor for TopP processor.
- legalize(input_type) – Legalize the processor into a list of low-level operators.