flashinfer.sampling.top_p_renorm_probs

flashinfer.sampling.top_p_renorm_probs(probs: torch.Tensor, top_p: torch.Tensor | float) → torch.Tensor

Fused GPU kernel for renormalizing probabilities by top-p thresholding.

Parameters:
  • probs (torch.Tensor) – Probabilities, shape (batch_size, num_classes).

  • top_p (Union[torch.Tensor, float]) – Either a scalar or a tensor of shape (batch_size,) giving the top-p threshold used to renormalize the probabilities; values should lie in (0, 1). If a scalar, the same threshold is used for all requests; if a tensor, each request has its own threshold. Probabilities below a per-row threshold are masked out, where the threshold is chosen so that the cumulative sum of probs[probs >= threshold] reaches top_p, and the remaining probabilities are renormalized to sum to 1.

Returns:

renorm_probs – Renormalized probabilities, shape (batch_size, num_classes).

Return type:

torch.Tensor
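To make the thresholding rule concrete, here is a dependency-free Python sketch of the same semantics. It is not FlashInfer's fused kernel, and the helper names `_renorm_row` and `top_p_renorm_reference` are hypothetical. The sketch assumes each row already sums to 1 and, mirroring the Union[torch.Tensor, float] parameter, accepts either one float shared by all rows or one float per row.

```python
from typing import List, Sequence, Union


def _renorm_row(row: Sequence[float], top_p: float) -> List[float]:
    # Walk the probabilities from largest to smallest until the running
    # mass reaches top_p; the last value visited becomes the threshold.
    mass = 0.0
    threshold = 0.0
    for p in sorted(row, reverse=True):
        mass += p
        threshold = p
        if mass >= top_p:
            break
    # Keep every entry >= threshold (ties included), zero the rest,
    # and rescale the survivors so the row sums to 1 again.
    kept = sum(p for p in row if p >= threshold)
    return [p / kept if p >= threshold else 0.0 for p in row]


def top_p_renorm_reference(
    probs: Sequence[Sequence[float]],
    top_p: Union[float, Sequence[float]],
) -> List[List[float]]:
    """Hypothetical pure-Python reference for top-p renormalization.

    probs: batch_size rows, each assumed to be a probability distribution.
    top_p: a single float shared by all rows, or one float per row.
    """
    if isinstance(top_p, (int, float)):
        thresholds = [float(top_p)] * len(probs)
    else:
        thresholds = [float(t) for t in top_p]
        if len(thresholds) != len(probs):
            raise ValueError("top_p must have one entry per row")
    for t in thresholds:
        if not 0.0 < t < 1.0:
            raise ValueError("top_p must lie in (0, 1)")
    return [_renorm_row(row, t) for row, t in zip(probs, thresholds)]
```

Applied to the rounded probabilities printed in the example with top_p = 0.3, this sketch keeps the same positions as the kernel output shown there, and the renormalized values agree to about three decimal places (the small differences come from rounding of the printed inputs).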

Examples

>>> import torch
>>> import flashinfer
>>> _ = torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_p = 0.3
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
        [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
        [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
        [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> renormed_probs = flashinfer.sampling.top_p_renorm_probs(prob, top_p)
>>> renormed_probs
tensor([[0.0000, 0.4882, 0.0000, 0.5118, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
        [0.5181, 0.0000, 0.4819, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000]], device='cuda:0')

Note

Applying top_p_renorm_probs followed by sampling_from_probs should be equivalent to calling top_p_sampling_from_probs directly.
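The two-step recipe in the Note can be illustrated with a hypothetical pure-Python sketch in which `random.choices` stands in for sampling_from_probs: renormalizing first and then sampling draws only from the top-p set, with probabilities proportional to the original ones. This is not how top_p_sampling_from_probs is implemented internally; the helper names `top_p_renorm_row`, `sample_from_row`, and `top_p_sample_two_step` are assumptions for illustration only.

```python
import random
from typing import List, Sequence


def top_p_renorm_row(row: Sequence[float], top_p: float) -> List[float]:
    # Hypothetical helper: find the threshold at which the running mass of
    # the largest probabilities first reaches top_p, then keep entries >= it.
    mass = 0.0
    threshold = 0.0
    for p in sorted(row, reverse=True):
        mass += p
        threshold = p
        if mass >= top_p:
            break
    kept = sum(p for p in row if p >= threshold)
    return [p / kept if p >= threshold else 0.0 for p in row]


def sample_from_row(row: Sequence[float], rng: random.Random) -> int:
    # Stand-in for sampling_from_probs: draw one index with
    # probability proportional to row[i]; zero-weight entries are never drawn.
    return rng.choices(range(len(row)), weights=row, k=1)[0]


def top_p_sample_two_step(row: Sequence[float], top_p: float, rng: random.Random) -> int:
    # The two-step recipe from the Note: renormalize, then sample.
    return sample_from_row(top_p_renorm_row(row, top_p), rng)


if __name__ == "__main__":
    rng = random.Random(0)
    row = [0.2522, 0.1602, 0.2346, 0.1532, 0.2000]
    draws = [top_p_sample_two_step(row, 0.3, rng) for _ in range(10_000)]
    # Only indices 0 and 2 survive top_p = 0.3 for this row.
    print(sorted(set(draws)))
    # Empirical frequency of index 0; expected to be near 0.2522 / 0.4868.
    print(draws.count(0) / len(draws))
```

Run as a script, it should print `[0, 2]` followed by a frequency near 0.52, matching the renormalized probability of index 0 in the third row of the example.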