flashinfer.sampling.top_p_renorm_prob
- flashinfer.sampling.top_p_renorm_prob(probs: torch.Tensor, top_p: float, eps: float = 1e-05)
Fused GPU kernel for renormalizing probabilities by top-p thresholding.
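The top-p thresholding rule can be illustrated with a plain-Python reference sketch. It is not the fused kernel. The function name `top_p_renorm_reference` is hypothetical, the function works on nested lists rather than `torch.Tensor`, it ignores `eps`, and it assumes that every entry tied with the boundary probability is kept. Within those assumptions, it walks each row from the largest probability to the smallest until the running sum reaches `top_p`, zeroes everything below the resulting threshold, and rescales the survivors so that the row sums to 1.

```python
from typing import List


def top_p_renorm_reference(probs: List[List[float]], top_p: float) -> List[List[float]]:
    """Hypothetical plain-Python reference for top-p renormalization.

    For each row, find the largest threshold such that the entries
    >= threshold sum to at least top_p, zero out every entry below it,
    and rescale the remaining entries so the row sums to 1.
    (eps is not modeled here; tie handling at the boundary is assumed.)
    """
    if not 0.0 < top_p < 1.0:
        raise ValueError("top_p should be in (0, 1)")
    out = []
    for row in probs:
        # Walk the probabilities from largest to smallest until the
        # running sum first reaches top_p; that entry is the threshold.
        cumsum = 0.0
        threshold = 0.0
        for p in sorted(row, reverse=True):
            cumsum += p
            threshold = p
            if cumsum >= top_p:
                break
        # Keep every entry >= threshold (all ties at the boundary survive).
        kept = [p if p >= threshold else 0.0 for p in row]
        total = sum(kept)
        out.append([p / total for p in kept])
    return out
```

For example, with the row `[0.5, 0.3, 0.1, 0.1]` and `top_p=0.75`, the threshold lands at 0.3, so the output row is approximately `[0.625, 0.375, 0.0, 0.0]`.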
- Parameters:
probs (torch.Tensor) – Probabilities, shape (batch_size, num_classes).
top_p (float) – The threshold for renormalizing probabilities; should be in (0, 1). We mask out the probabilities below a threshold, chosen so that the cumulative sum of probs[probs >= threshold] is top_p, and then renormalize the remaining probabilities.
eps (float) – The epsilon value for numerical stability.
- Returns:
renorm_probs (torch.Tensor) – Renormalized probabilities, shape (batch_size, num_classes).

Applying top_p_renorm_prob and then sampling_from_probs should be equivalent to top_p_sampling_from_probs.
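The equivalence above amounts to "renormalize, then sample from the result", read here as producing the same output distribution. A minimal plain-Python sketch of that two-step route follows. `renorm_top_p_row`, `sample_from_row`, and `top_p_sample_row` are hypothetical helpers, sampling uses inverse-CDF lookup driven by Python's `random` module, and none of this mirrors how flashinfer's kernels draw random numbers.

```python
import random
from typing import List


def renorm_top_p_row(row: List[float], top_p: float) -> List[float]:
    # Hypothetical single-row top-p renormalization; assumes entries tied
    # with the boundary probability are all kept.
    cumsum, threshold = 0.0, 0.0
    for p in sorted(row, reverse=True):
        cumsum, threshold = cumsum + p, p
        if cumsum >= top_p:
            break
    kept = [p if p >= threshold else 0.0 for p in row]
    total = sum(kept)
    return [p / total for p in kept]


def sample_from_row(row: List[float], u: float) -> int:
    # Inverse-CDF sampling: return the first index whose running sum
    # exceeds u, where u is uniform in [0, 1). Zero entries are never chosen.
    cumsum = 0.0
    for i, p in enumerate(row):
        cumsum += p
        if u < cumsum:
            return i
    # Guard against round-off leaving u just above the final running sum.
    return max(i for i, p in enumerate(row) if p > 0)


def top_p_sample_row(row: List[float], top_p: float, rng: random.Random) -> int:
    # Renormalize first, then sample from the renormalized distribution.
    return sample_from_row(renorm_top_p_row(row, top_p), rng.random())
```

Sampling repeatedly from `[0.5, 0.3, 0.1, 0.1]` with `top_p=0.75` should only ever return index 0 or 1, with index 0 drawn roughly 62.5% of the time.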