flashinfer.sampling.top_k_renorm_prob

flashinfer.sampling.top_k_renorm_prob(probs: torch.Tensor, top_k: int, eps: float = 1e-05)

Fused GPU kernel for renormalizing probabilities by top-k thresholding.

Parameters:
  • probs (torch.Tensor) – Probabilities, shape (batch_size, num_classes).

  • top_k (int) – The top-k threshold for renormalizing probabilities; should be in (0, num_classes). For each row, the top_k largest probabilities are kept, the rest are set to zero, and the kept probabilities are renormalized.

  • eps (float) – The epsilon value for numerical stability. Defaults to 1e-05.

Returns:

renorm_probs – Renormalized probabilities, shape (batch_size, num_classes).

Return type:

torch.Tensor

Note

Applying top_k_renorm_prob and then sampling_from_probs should be equivalent to calling top_k_sampling_from_probs.
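The thresholding described above can be illustrated with a minimal pure-Python sketch. It is a reference for the semantics only, not the fused GPU kernel: top_k_renorm_reference is a hypothetical helper, plain lists stand in for torch.Tensor, ties are broken by lower index (an assumption), and the eps parameter is not modeled. The last lines sketch the two-step pipeline from the note by sampling one index per row from the renormalized probabilities.

```python
import random


def top_k_renorm_reference(probs, top_k):
    """Keep the top_k largest entries of each row, zero the rest, renormalize.

    Hypothetical reference helper; probs is a list of rows (lists of floats).
    """
    out = []
    for row in probs:
        # Indices of the top_k largest probabilities.
        # sorted() is stable, so ties go to the lower index (an assumption).
        keep = set(sorted(range(len(row)), key=lambda i: -row[i])[:top_k])
        total = sum(row[i] for i in keep)
        # Kept entries are rescaled to sum to 1; all others become zero.
        out.append([row[i] / total if i in keep else 0.0 for i in range(len(row))])
    return out


probs = [[0.1, 0.4, 0.2, 0.3], [0.25, 0.25, 0.4, 0.1]]
renorm = top_k_renorm_reference(probs, top_k=2)

# Second step of the pipeline in the note: draw one class index per row
# from the renormalized probabilities. Zeroed entries are never selected.
rng = random.Random(0)
samples = [rng.choices(range(len(row)), weights=row, k=1)[0] for row in renorm]
```

For the first row, only indices 1 and 3 survive, with values 0.4/0.7 and 0.3/0.7, so any sampled index is one of the original top-2 classes.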