flashinfer.sampling.top_k_renorm_prob
- flashinfer.sampling.top_k_renorm_prob(probs: torch.Tensor, top_k: int, eps: float = 1e-05)
Fused GPU kernel for renormalizing probabilities by top-k thresholding.
- Parameters:
probs (torch.Tensor) – Probabilities, shape (batch_size, num_classes).
top_k (int) – The threshold for re-normalizing probabilities, should be in (0, num_classes). We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.
eps (float) – The epsilon value for numerical stability.
- Returns:
renorm_probs – Renormalized probabilities, shape (batch_size, num_classes).
- Return type:
torch.Tensor
Note
The combination of top_k_renorm_prob and sampling_from_probs should be equivalent to top_k_sampling_from_probs.
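The renormalization described above can be illustrated with a small pure-Python reference for a single row of probabilities. This is only a sketch of the documented semantics (keep the top-k entries, zero the rest, renormalize), not the fused GPU kernel. The helper name `top_k_renorm_row` is hypothetical, the `eps` parameter is not modeled, and breaking ties in favor of the lower index is an assumption that the actual kernel may not share.

```python
def top_k_renorm_row(probs, top_k):
    """Reference sketch for one row: keep the top_k largest entries,
    set the rest to zero, and renormalize so the row sums to 1.

    Hypothetical helper, not part of flashinfer.
    """
    num_classes = len(probs)
    # The documentation states that top_k should lie in (0, num_classes).
    if not 0 < top_k < num_classes:
        raise ValueError("top_k must be in (0, num_classes)")

    # Indices of the top_k largest probabilities.
    # Assumption: ties are broken in favor of the lower index (stable sort).
    keep = sorted(range(num_classes), key=lambda i: probs[i], reverse=True)[:top_k]

    # Total probability mass of the kept entries; assumed to be positive.
    kept_mass = sum(probs[i] for i in keep)

    # Everything outside the top_k stays at zero.
    renorm = [0.0] * num_classes
    for i in keep:
        renorm[i] = probs[i] / kept_mass
    return renorm


row = [0.1, 0.4, 0.2, 0.3]
out = top_k_renorm_row(row, top_k=2)
print(out)
```

Here the two largest entries (0.4 and 0.3) are kept and rescaled by their combined mass, so the output row again sums to 1 and has zeros in the other positions. Sampling from such a row can only return one of the top-k classes, which is the intuition behind the equivalence stated in the note.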