flashinfer.sampling.top_k_renorm_probs
- flashinfer.sampling.top_k_renorm_probs(probs: torch.Tensor, top_k: torch.Tensor | int) → torch.Tensor
Fused GPU kernel for renormalizing probabilities by top-k thresholding.
- Parameters:
  - probs (torch.Tensor) – Probabilities, shape (batch_size, num_classes).
  - top_k (Union[torch.Tensor, int]) – Either a scalar or a tensor of shape (batch_size,), representing the top-k threshold for re-normalizing probabilities; should be in (0, num_classes). If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold (a per-request sketch follows the scalar example in Examples below). We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.
- Returns:
  renorm_probs – Renormalized probabilities, shape (batch_size, num_classes).
- Return type:
  torch.Tensor
Examples
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_k = 3
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
        [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
        [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
        [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> renormed_probs = flashinfer.sampling.top_k_renorm_probs(prob, top_k)
>>> renormed_probs
tensor([[0.3201, 0.3319, 0.0000, 0.3480, 0.0000],
        [0.2573, 0.0000, 0.3398, 0.4028, 0.0000],
        [0.3672, 0.0000, 0.3416, 0.0000, 0.2912],
        [0.0000, 0.4243, 0.2750, 0.0000, 0.3007]], device='cuda:0')
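A minimal sketch of the per-request form of top_k, continuing the session above (the int32 dtype for the threshold tensor is an assumption, and the printed result is the expected outcome under the documented semantics, assuming no exact ties among probabilities): each row keeps exactly that row's top-k entries, so counting the nonzeros per row recovers the thresholds.
>>> top_k_tensor = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device="cuda:0")
>>> renormed_per_request = flashinfer.sampling.top_k_renorm_probs(prob, top_k_tensor)
>>> (renormed_per_request > 0).sum(dim=-1)  # nonzero count per row matches each request's k
tensor([1, 2, 3, 4], device='cuda:0')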
Note
This combination of top_k_renorm_probs and sampling_from_probs should be equivalent to top_k_sampling_from_probs.
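A minimal sketch of that equivalence, assuming sampling_from_probs takes the probability tensor and top_k_sampling_from_probs takes the probability tensor plus the top-k threshold, with both returning sampled token indices (see their own pages for the exact signatures and return values):
>>> renormed = flashinfer.sampling.top_k_renorm_probs(prob, top_k)
>>> two_step_samples = flashinfer.sampling.sampling_from_probs(renormed)
>>> fused_samples = flashinfer.sampling.top_k_sampling_from_probs(prob, top_k)
>>> # Both draws target the same top-k-truncated distribution, so they agree
>>> # in distribution; individual samples differ because sampling is random.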