flashinfer.sampling.top_k_renorm_probs

flashinfer.sampling.top_k_renorm_probs(probs: torch.Tensor, top_k: torch.Tensor | int) → torch.Tensor

Fused GPU kernel for renormalizing probabilities by top-k thresholding.

Parameters:
  • probs (torch.Tensor) – Probabilities, shape (batch_size, num_classes).

  • top_k (Union[torch.Tensor, int]) – Either a scalar or a tensor of shape (batch_size,), representing the top-k threshold for renormalizing probabilities; should be in (0, num_classes). If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. For each request, the top-k probabilities are kept, the rest are set to zero, and the kept probabilities are renormalized to sum to 1.

Returns:

renorm_probs – Renormalized probabilities, shape (batch_size, num_classes).

Return type:

torch.Tensor

Examples

>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_k = 3
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
        [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
        [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
        [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> renormed_probs = flashinfer.sampling.top_k_renorm_probs(prob, top_k)
>>> renormed_probs
tensor([[0.3201, 0.3319, 0.0000, 0.3480, 0.0000],
        [0.2573, 0.0000, 0.3398, 0.4028, 0.0000],
        [0.3672, 0.0000, 0.3416, 0.0000, 0.2912],
        [0.0000, 0.4243, 0.2750, 0.0000, 0.3007]], device='cuda:0')
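
The semantics shown above can be checked against a small reference written in plain Python. This is only a sketch of the documented behavior (keep the top-k entries, zero the rest, renormalize), not the fused kernel itself. The helper name top_k_renorm_reference is hypothetical, it works on nested lists rather than tensors so it runs without a GPU, and its handling of ties (keeping every entry equal to the k-th largest value) is an assumption that the kernel may not share.

```python
def top_k_renorm_reference(probs, top_k):
    """Reference top-k renormalization on nested lists.

    Hypothetical helper for illustration; not part of the FlashInfer API.
    """
    # Accept a scalar k or one k per row, mirroring the documented top_k argument.
    ks = [top_k] * len(probs) if isinstance(top_k, int) else list(top_k)
    out = []
    for row, k in zip(probs, ks):
        # The k-th largest value acts as the threshold for this row.
        threshold = sorted(row, reverse=True)[k - 1]
        # Assumption: entries tied with the threshold are all kept.
        kept = [p if p >= threshold else 0.0 for p in row]
        total = sum(kept)
        out.append([p / total for p in kept])
    return out


# First two rows of the example above, rounded to four decimals.
prob = [
    [0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
    [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
]
renormed = top_k_renorm_reference(prob, 3)        # same k for every row
per_row = top_k_renorm_reference(prob, [3, 1])    # one k per row
```

Because the inputs here are rounded, the results agree with the kernel output above only to about three decimal places.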

Note

Applying top_k_renorm_probs followed by sampling_from_probs should be equivalent to calling top_k_sampling_from_probs.
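
One way to see why the two-step pipeline behaves like top-k sampling is that the renormalized distribution puts zero mass on everything outside the top k, so an ordinary sampler can only return top-k indices. The sketch below illustrates this in plain Python. Both helpers are hypothetical stand-ins rather than FlashInfer functions, and inverse-CDF sampling is used purely for illustration; it is not a claim about how the library's kernels sample.

```python
import random


def top_k_renorm(row, k):
    """Keep the k largest entries of one row, zero the rest, renormalize."""
    threshold = sorted(row, reverse=True)[k - 1]
    kept = [p if p >= threshold else 0.0 for p in row]
    total = sum(kept)
    return [p / total for p in kept]


def sample_from_probs(row, u):
    """Inverse-CDF sampling: first nonzero index whose cumulative mass exceeds u."""
    cumulative = 0.0
    last_nonzero = 0
    for i, p in enumerate(row):
        if p > 0.0:
            cumulative += p
            last_nonzero = i
            if u < cumulative:
                return i
    # Guard against rounding when u is very close to 1.
    return last_nonzero


rng = random.Random(0)
row = [0.2499, 0.2592, 0.1085, 0.2718, 0.1106]

# Step 1: renormalize over the top 3 entries (indices 0, 1, and 3).
renormed = top_k_renorm(row, 3)

# Step 2: sample from the renormalized row.
samples = [sample_from_probs(renormed, rng.random()) for _ in range(10_000)]

# Indices 2 and 4 were zeroed out, so they can never be drawn.
sampled_indices = sorted(set(samples))
```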