flashinfer.sampling.top_p_renorm_probs
- flashinfer.sampling.top_p_renorm_probs(probs: torch.Tensor, top_p: torch.Tensor | float) → torch.Tensor
Fused GPU kernel for renormalizing probabilities by top-p thresholding.
- Parameters:
  - probs (torch.Tensor) – Probabilities, shape (batch_size, num_classes).
  - top_p (Union[torch.Tensor, float]) – Either a scalar or a tensor of shape (batch_size,), representing the top-p threshold for re-normalizing probabilities; should be in (0, 1). If a scalar, the same threshold is used for all requests; if a tensor, each request has its own threshold. We mask out the probabilities less than threshold, where threshold is chosen such that the cumulative sum of probs[probs >= threshold] is top_p, and renormalize the remaining probabilities.
- Returns:
  renorm_probs – Renormalized probabilities, shape (batch_size, num_classes).
- Return type:
  torch.Tensor
Examples
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_p = 0.3
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
        [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
        [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
        [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> renormed_probs = flashinfer.sampling.top_p_renorm_probs(prob, top_p)
>>> renormed_probs
tensor([[0.0000, 0.4882, 0.0000, 0.5118, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
        [0.5181, 0.0000, 0.4819, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000]], device='cuda:0')
Note

This combination of top_p_renorm_probs and sampling_from_probs should be equivalent to top_p_sampling_from_probs.
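The masking rule described above (keep the smallest set of highest-probability tokens whose cumulative mass reaches top_p, zero out the rest, renormalize) can be sketched in pure PyTorch. This is a hypothetical reference helper for illustration, not part of flashinfer, and it assumes a scalar top_p rather than a per-request tensor:

```python
import torch


def top_p_renorm_probs_ref(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Hypothetical pure-PyTorch reference for top-p renormalization.

    Keeps the smallest set of highest-probability tokens whose cumulative
    sum reaches top_p, masks out the rest, and renormalizes.
    """
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    # A sorted position is kept if the cumulative mass *before* it is still
    # below top_p, i.e. the token is needed to reach the top_p threshold.
    keep_sorted = (cumsum - sorted_probs) < top_p
    # Scatter the keep mask back to the original (unsorted) token order.
    keep = torch.zeros_like(keep_sorted).scatter(-1, sorted_idx, keep_sorted)
    masked = torch.where(keep, probs, torch.zeros_like(probs))
    return masked / masked.sum(dim=-1, keepdim=True)
```

On the first row of the example above ([0.2499, 0.2592, 0.1085, 0.2718, 0.1106] with top_p = 0.3), this keeps the two largest entries, 0.2718 and 0.2592, whose cumulative sum 0.5310 is the first to cross 0.3, matching the fused kernel's output. The fused kernel avoids the explicit sort-and-scatter round trip on the GPU.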