flashinfer.sampling.min_p_sampling_from_probs

flashinfer.sampling.min_p_sampling_from_probs(probs: torch.Tensor, uniform_samples: torch.Tensor, min_p: torch.Tensor | float, deterministic: bool = True, check_nan: bool = False) → Tuple[torch.Tensor, torch.Tensor]

Fused GPU kernel for min-p sampling from probabilities. This operator implements GPU-based rejection sampling without explicit sorting.

All rounds of rejection sampling run inside a single CUDA kernel, which is more efficient than a naive implementation that launches a series of kernels.
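To make "multiple rounds of rejection sampling without sorting" concrete, here is a minimal plain-Python sketch of one common pivot-based scheme. It is an illustration under stated assumptions, not FlashInfer's CUDA kernel: the helper name `min_p_rejection_sample` is hypothetical, and it assumes the standard min-p rule (keep tokens whose probability is at least `min_p * max(probs)`). Each round consumes one uniform needle, samples by inverse CDF among tokens above a pivot, and either accepts or raises the pivot to the rejected token's probability.

```python
from typing import Sequence, Tuple


def min_p_rejection_sample(
    probs: Sequence[float],
    uniforms: Sequence[float],
    min_p: float,
) -> Tuple[int, bool]:
    """Hypothetical reference helper: one row of pivot-based min-p rejection sampling.

    Returns (index, success). success is False if no round was accepted.
    This illustrates the technique; it is not FlashInfer's kernel.
    """
    # Assumed min-p rule: keep tokens with prob >= min_p * max(probs).
    threshold = min_p * max(probs)
    pivot = 0.0  # tokens with prob <= pivot are excluded from later rounds
    last = 0
    for u in uniforms:  # one uniform "needle" per round
        # Restrict to tokens above the pivot with a linear scan; no sorting.
        candidates = [i for i, p in enumerate(probs) if p > pivot]
        if not candidates:
            break  # nothing left to try (only possible when min_p > 1)
        mass = sum(probs[i] for i in candidates)
        # Inverse-CDF walk: first candidate whose cumulative mass exceeds u * mass.
        target = u * mass
        cumulative = 0.0
        chosen = candidates[-1]  # fallback guards against round-off at u close to 1
        for i in candidates:
            cumulative += probs[i]
            if cumulative > target:
                chosen = i
                break
        last = chosen
        if probs[chosen] >= threshold:
            return chosen, True  # accepted: stop early
        # Rejected: drop this token and every token that is no more likely.
        pivot = probs[chosen]
    return last, False
```

With `probs = [0.5, 0.3, 0.15, 0.05]` and `min_p = 0.2`, the threshold is 0.1. A needle of 0.99 lands on index 3 and is rejected, so the next round samples only among indices 0 to 2. Because every round samples proportionally to `probs` over a superset of the kept tokens, an accepted index follows `probs` renormalized over the kept tokens, and each rejection removes at least one token, which is why early stopping typically needs only a few rounds.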

Parameters:
  • probs (torch.Tensor) – Probabilities, shape (batch_size, num_classes).

  • uniform_samples (torch.Tensor) – The uniform samples used as needles for sampling, shape (max_top_k_rounds, batch_size), where the first dimension is the maximum number of rejection-sampling rounds. Expected to be uniformly distributed in [0, 1).

  • min_p (torch.Tensor or float) – Either a scalar or a tensor of shape (batch_size,) giving the threshold for min-p sampling. A scalar applies the same threshold to every request; a tensor gives each request its own threshold.

  • deterministic (bool) – Whether to use the deterministic kernel implementation. Default is True.

  • check_nan (bool) – Whether to check probs for NaN values. Default is False.
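The target distribution that min_p selects can be written down directly. The plain-Python sketch below assumes the standard min-p rule (a token is kept when its probability is at least `min_p` times the row's largest probability) and shows how a scalar `min_p` is shared across the batch while a per-request sequence gives each row its own threshold. The helper name `min_p_filter` is hypothetical and only describes the intended semantics; it is not part of FlashInfer.

```python
from typing import List, Sequence, Union


def min_p_filter(
    probs: Sequence[Sequence[float]],
    min_p: Union[float, Sequence[float]],
) -> List[List[float]]:
    """Hypothetical helper: each row's min-p filtered, renormalized distribution."""
    batch_size = len(probs)
    # A scalar min_p is shared by every request; a sequence gives one per request.
    if isinstance(min_p, (int, float)):
        thresholds = [float(min_p)] * batch_size
    else:
        if len(min_p) != batch_size:
            raise ValueError("min_p must be a scalar or have length batch_size")
        thresholds = [float(m) for m in min_p]
    filtered = []
    for row, m in zip(probs, thresholds):
        # Assumed rule: the cutoff scales with the row's top probability.
        cutoff = m * max(row)
        kept = [p if p >= cutoff else 0.0 for p in row]
        total = sum(kept)  # assumes each row has at least one positive entry
        filtered.append([p / total for p in kept])
    return filtered
```

For example, `min_p_filter([[0.9, 0.1]], 0.5)` uses a cutoff of 0.45, drops the second token, and returns `[[1.0, 0.0]]`. In the documented example, `min_p = 0.05` leaves every token above its row's cutoff, so no token is filtered out there.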

Returns:

  • samples (torch.Tensor) – Sampled categories, shape (batch_size,).

  • success (torch.Tensor) – Whether sampling succeeded within max_top_k_rounds rounds, shape (batch_size,).
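Rows whose success entry is False were not accepted within the available rounds, and this page does not say what index is stored in samples for those rows, so a caller may want an explicit fallback. The plain-Python sketch below uses greedy decoding (the row's argmax) as the fallback; that policy and the helper name `fill_failed_samples` are assumptions for illustration, not behavior FlashInfer provides.

```python
from typing import List, Sequence


def fill_failed_samples(
    samples: Sequence[int],
    success: Sequence[bool],
    probs: Sequence[Sequence[float]],
) -> List[int]:
    """Hypothetical policy: replace samples from failed rows with that row's argmax."""
    result = []
    for sample, ok, row in zip(samples, success, probs):
        if ok:
            result.append(sample)
        else:
            # Greedy fallback: the most likely token always passes min-p when min_p <= 1.
            result.append(max(range(len(row)), key=row.__getitem__))
    return result
```

With tensors, the same idea would look something like `torch.where(success, samples, probs.argmax(dim=-1).to(samples.dtype))`, which keeps the whole fallback on the GPU.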

Examples

>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f8b3db06df0>
>>> batch_size = 4
>>> vocab_size = 5
>>> max_rounds = 3
>>> min_p = torch.full((batch_size,), 0.05).to(0)
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> norm_prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
        [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
        [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
        [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> uniform_samples = torch.rand(max_rounds, batch_size).to(0)
>>> samples, success = flashinfer.sampling.min_p_sampling_from_probs(norm_prob, uniform_samples, min_p)
>>> samples
tensor([1, 2, 1, 4], device='cuda:0', dtype=torch.int32)
>>> success
tensor([True, True, True, True], device='cuda:0')

Note

This function expects float32 inputs and returns int32 sample indices. We recommend setting max_rounds to a reasonable value, e.g., 32. In practice, the implementation usually needs far fewer rounds of rejection sampling because it stops early once a sample is accepted.