flashinfer.sampling.top_k_top_p_sampling_from_probs

flashinfer.sampling.top_k_top_p_sampling_from_probs(probs: torch.Tensor, top_k: torch.Tensor | int, top_p: torch.Tensor | float, indices: torch.Tensor | None = None, filter_apply_order: str = 'top_k_first', deterministic: bool = True, generator: torch.Generator | None = None, check_nan: bool = False) -> torch.Tensor

Fused GPU kernel for top-k and top-p sampling from probabilities.

This operator implements GPU-based rejection sampling without explicit sorting. Check the blog post for more details.

All rounds of rejection sampling are fused into a single CUDA kernel, which is more efficient than a naive implementation that launches a series of kernels.
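To illustrate how rejection sampling can respect top-k and top-p without sorting, here is a minimal plain-Python sketch that assumes a pivot-based scheme: draw a token from the tokens above the current pivot, accept it if it lies in the filtered set, otherwise raise the pivot to its probability and draw again. It is a conceptual CPU sketch, not FlashInfer's CUDA kernel. It applies both filters against the original distribution (one reading of the "joint" order), does not model "top_k_first", and simplifies tie handling. The function names are hypothetical.

```python
import random


def _sample_above_pivot(probs, pivot, u):
    """Inverse-CDF sample among tokens whose probability exceeds `pivot`.

    No sorting is needed: one linear scan accumulates the unnormalized mass.
    """
    total = sum(q for q in probs if q > pivot)
    target = u * total
    acc, last = 0.0, None
    for i, q in enumerate(probs):
        if q > pivot:
            acc += q
            last = i
            if acc > target:
                return i
    return last  # only reached through floating-point round-off


def top_k_top_p_rejection_sample(probs, top_k, top_p, rng):
    """Draw one token from the top-k and top-p filtered set of `probs` (sums to 1)."""
    pivot = 0.0
    while True:
        i = _sample_above_pivot(probs, pivot, rng.random())
        q = probs[i]
        higher = [x for x in probs if x > q]
        # Accept if fewer than top_k tokens, and less than top_p mass, rank above i.
        if len(higher) < top_k and sum(higher) < top_p:
            return i
        # Reject: every token with probability <= q fails too, so raise the pivot.
        pivot = q


rng = random.Random(0)
probs = [0.1, 0.2, 0.3, 0.4]
draws = [top_k_top_p_rejection_sample(probs, 2, 1.0, rng) for _ in range(10_000)]
print(sorted(set(draws)))  # [2, 3]
```

Each rejection strictly raises the pivot, and the most likely token always passes both tests (for `top_k >= 1` and `top_p > 0`), so the loop terminates. Because every round samples proportionally to `probs` over a superset of the accepted tokens, the accepted token follows the renormalized filtered distribution. A GPU kernel can run all of these rounds within one launch; the sketch only shows the acceptance logic.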

Parameters:
  • probs (torch.Tensor) – Probabilities for sampling. When indices is not provided, shape should be (batch_size, num_classes) and the i-th output will be sampled from the i-th row of probabilities. When indices is provided, shape should be (unique_batch_size, num_classes) where unique_batch_size is the number of unique probability distributions.

  • top_k (Union[torch.Tensor, int]) – Either a scalar or a tensor of shape (batch_size,), representing the threshold for top-k sampling. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold.

  • top_p (Union[torch.Tensor, float]) – Either a scalar or a tensor of shape (batch_size,), representing the threshold for top-p sampling. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold.

  • indices (Optional[torch.Tensor]) – Optional indices tensor of shape (batch_size,) that maps each output to a row in probs. For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of probs.

  • filter_apply_order (str) – The order in which the top-k and top-p filters are applied; must be either "top_k_first" or "joint". If "top_k_first", the top-k filter is applied first, and top-p sampling is then performed on the top-k results. If "joint", the top-k and top-p filters are applied simultaneously in each round. Default is "top_k_first".

  • deterministic (bool) – Whether to use the deterministic kernel implementation. Default is True.

  • generator (Optional[torch.Generator]) – A random number generator for the operation.

  • check_nan (bool) – Whether to check probs for NaN values. Default is False.
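The two filter_apply_order values can keep different token sets. The sort-based plain-Python reference below computes the allowed set for each mode. It is a hypothetical helper (`allowed_tokens`), not part of FlashInfer. It assumes that "top_k_first" renormalizes the top-k tokens before applying top-p, that "joint" intersects both filters measured against the original distribution, and that ties are broken by token id, which may not match the kernel.

```python
def allowed_tokens(probs, top_k, top_p, order="top_k_first"):
    """Sort-based reference: the set of token ids a filter order keeps."""
    # Token ids from most to least likely; ties keep ascending id order.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    def top_p_prefix(ids):
        # Smallest prefix of `ids` whose renormalized mass reaches top_p.
        total = sum(probs[i] for i in ids)
        kept, cum = [], 0.0
        for i in ids:
            kept.append(i)
            cum += probs[i] / total
            if cum >= top_p:
                break
        return kept

    top_k_ids = ranked[:top_k]
    if order == "top_k_first":
        # Top-p runs on the top-k tokens after renormalizing them.
        return set(top_p_prefix(top_k_ids))
    if order == "joint":
        # Both filters are measured against the original distribution.
        return set(top_k_ids) & set(top_p_prefix(ranked))
    raise ValueError("order must be 'top_k_first' or 'joint'")


probs = [0.4, 0.3, 0.2, 0.1]
print(allowed_tokens(probs, 2, 0.5, "top_k_first"))  # {0}
print(allowed_tokens(probs, 2, 0.5, "joint"))        # {0, 1}
```

With top_k = 2 and top_p = 0.5, renormalizing the top two tokens gives token 0 a share of about 0.57, which alone reaches top_p, so "top_k_first" keeps only token 0. Measured against the original distribution, token 0 holds only 0.4, so "joint" also keeps token 1.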

Returns:

samples – Sampled categories, shape (batch_size,).

Return type:

torch.Tensor
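The output has one entry per request. Without indices, batch_size is the number of rows in probs. With indices, batch_size is len(indices), and a scalar top_k or top_p is shared by every request. The hypothetical helper `resolve_requests` below sketches this bookkeeping in plain Python, using lists in place of tensors. It is not part of FlashInfer, and the library's own validation may differ.

```python
def resolve_requests(num_rows, top_k, top_p, indices=None):
    """Return one (row, k, p) tuple per output sample.

    num_rows: number of rows in probs (batch_size, or unique_batch_size with indices).
    top_k / top_p: a scalar shared by all requests, or a per-request list.
    indices: optional list mapping output i to row indices[i] of probs.
    """
    rows = list(range(num_rows)) if indices is None else list(indices)
    if any(not 0 <= r < num_rows for r in rows):
        raise IndexError("indices must point at existing rows of probs")
    batch_size = len(rows)  # length of the returned samples
    ks = [top_k] * batch_size if isinstance(top_k, int) else list(top_k)
    ps = [top_p] * batch_size if isinstance(top_p, (int, float)) else list(top_p)
    if len(ks) != batch_size or len(ps) != batch_size:
        raise ValueError("per-request top_k/top_p must have shape (batch_size,)")
    return list(zip(rows, ks, ps))


# Two distinct distributions, three outputs: outputs 0 and 1 share row 0.
print(resolve_requests(2, 2, [0.9, 0.5, 0.8], indices=[0, 0, 1]))
# [(0, 2, 0.9), (0, 2, 0.5), (1, 2, 0.8)]
```

Pointing several outputs at the same row, as in `indices=[0, 0, 1]`, avoids duplicating rows of probs when the same distribution is sampled more than once.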

Examples

>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_p = torch.full((batch_size,), 0.2).to(0)
>>> top_k = torch.full((batch_size,), 2).to(0)
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> norm_prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
        [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
        [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
        [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> samples = flashinfer.sampling.top_k_top_p_sampling_from_probs(norm_prob, top_k, top_p)
>>> samples
tensor([3, 3, 0, 1], device='cuda:0', dtype=torch.int32)
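In this example, top_p = 0.2 is below the largest probability in every row, both in the original row and after renormalizing over the top-2 tokens. Assuming top-p keeps the smallest set of highest-probability tokens whose mass reaches top_p, only the most likely token in each row survives under either filter order. The sampled output should therefore equal the per-row argmax regardless of the seed. The plain-Python check below recomputes this from the printed norm_prob values.

```python
# Rows of norm_prob as printed in the example above (rounded to 4 decimals).
norm_prob = [
    [0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
    [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
    [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
    [0.1543, 0.3182, 0.2062, 0.0958, 0.2255],
]
top_k, top_p = 2, 0.2

for row in norm_prob:
    top = sorted(row, reverse=True)[:top_k]
    # The most likely token alone already reaches top_p, both in the original
    # row and after renormalizing over the top-k tokens, so it is the only survivor.
    assert top[0] >= top_p and top[0] / sum(top) >= top_p

argmax = [max(range(len(row)), key=row.__getitem__) for row in norm_prob]
print(argmax)  # [3, 3, 0, 1]
```

To see actual randomness in the output, raise top_p above the largest probability in a row (for example 0.9) so that more than one token survives the filters.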

Note

This function expects float32 inputs, and the output is int32.