flashinfer.sampling.top_k_top_p_sampling_from_logits

flashinfer.sampling.top_k_top_p_sampling_from_logits(logits: torch.Tensor, uniform_samples: torch.Tensor, top_k: torch.Tensor | int, top_p: torch.Tensor | float, filter_apply_order: str = 'top_k_first', deterministic: bool = True, check_nan: bool = False) → Tuple[torch.Tensor, torch.Tensor]

Fused GPU kernel for top-k and top-p sampling from pre-softmax logits.

This operator implements GPU-based rejection sampling without explicit sorting.

The multiple rounds of rejection sampling run in a single CUDA kernel, which is more efficient than a naive implementation that launches a series of kernels.

Parameters:
  • logits (torch.Tensor) – Pre-softmax logits, shape (batch_size, num_classes).

  • uniform_samples (torch.Tensor) – The uniform samples used as needle for sampling, shape (max_top_k_rounds, batch_size,), where the first dimension is the maximum number of rounds for rejection sampling. Expected to be uniformly distributed in [0, 1).

  • top_k (Union[torch.Tensor, int]) – Either a scalar or a tensor of shape (batch_size,), representing the threshold for top-k sampling. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold.

  • top_p (Union[torch.Tensor, float]) – Either a scalar or a tensor of shape (batch_size,), representing the threshold for top-p sampling. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold.

  • filter_apply_order (str) – The order of applying top-k and top-p sampling, should be either "top_k_first" or "joint". If "top_k_first", we first apply top-k filter, then apply top-p sampling on the top-k results. If "joint", we apply top-k and top-p filter simultaneously in each round. Default is "top_k_first".

  • deterministic (bool) – Whether to use deterministic kernel implementation, default is True.

  • check_nan (bool) – Whether to check for NaN in the input logits, default is False.

Returns:

  • samples (torch.Tensor) – Sampled categories, shape (batch_size,).

  • success (torch.Tensor) – Whether the sampling is successful within max_top_k_rounds rounds, shape (batch_size,).

Examples

>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> max_rounds = 3
>>> top_p = 0.5
>>> top_k = 3
>>> logits = torch.randn(batch_size, vocab_size).to(0)
>>> logits
tensor([[ 1.9269,  1.4873,  0.9007, -2.1055, -0.7581],
        [ 1.0783,  0.8008,  1.6806,  0.3559, -0.6866],
        [-0.4934,  0.2415, -0.2316,  0.0418, -0.2516],
        [ 0.8599, -0.3097, -0.3957,  0.8034, -0.6216]], device='cuda:0')
>>> uniform_samples = torch.rand(max_rounds, batch_size).to(0)
>>> samples, success = flashinfer.sampling.top_k_top_p_sampling_from_logits(logits, uniform_samples, top_k, top_p)
>>> samples
tensor([0, 2, 1, 3], device='cuda:0', dtype=torch.int32)
>>> success
tensor([True, True, True, True], device='cuda:0')
>>> probs = torch.softmax(logits, dim=-1)
>>> probs
tensor([[0.4788, 0.3085, 0.1716, 0.0085, 0.0327],
        [0.2358, 0.1787, 0.4307, 0.1145, 0.0404],
        [0.1358, 0.2831, 0.1764, 0.2318, 0.1729],
        [0.3613, 0.1122, 0.1029, 0.3415, 0.0821]], device='cuda:0')
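For illustration, the "top_k_first" filtering order can be emulated with plain, unfused PyTorch ops. The helper below is a hypothetical reference, not part of flashinfer, and relies on the explicit sorting that the fused kernel avoids:

```python
import torch

def top_k_first_reference(logits, top_k, top_p):
    """Unfused sketch of filter_apply_order="top_k_first": keep the top-k
    logits, softmax, then apply a top-p cutoff and sample. Illustrative
    only -- the fused kernel does this without explicit sorting."""
    # Top-k filter: mask everything below the k-th largest logit.
    kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    masked = torch.where(logits >= kth, logits,
                         torch.full_like(logits, float("-inf")))
    probs = torch.softmax(masked, dim=-1)
    # Top-p filter on the top-k result: keep the smallest prefix of
    # descending probabilities whose cumulative mass reaches top_p
    # (the highest-probability token is always kept).
    sorted_probs, idx = torch.sort(probs, dim=-1, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    keep_sorted = cumsum - sorted_probs < top_p
    keep = torch.zeros_like(keep_sorted).scatter(-1, idx, keep_sorted)
    probs = torch.where(keep, probs, torch.zeros_like(probs))
    probs = probs / probs.sum(dim=-1, keepdim=True)
    return torch.multinomial(probs, 1).squeeze(-1)

torch.manual_seed(0)
logits = torch.randn(4, 8)
samples = top_k_first_reference(logits, top_k=3, top_p=0.5)
```

Because this reference sorts and renormalizes explicitly, it launches several kernels per call; the fused operator replaces all of it with a single rejection-sampling kernel.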

Note

This function expects float32 inputs, and the output is int32. We encourage users to set max_rounds (the first dimension of uniform_samples) to a reasonable value, e.g., 32. The actual implementation usually uses far fewer rounds for rejection sampling because of early stopping.
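The early-stopping behavior can be sketched with a toy CPU loop. This is only an illustration of bounded-round sampling with early exit, not the algorithm used by the fused kernel (which, in particular, avoids the explicit sort used below to build the accept mask):

```python
import torch

def toy_rejection_top_p(probs, uniform_samples, top_p):
    """Toy illustration (NOT the fused kernel's algorithm): each round
    draws one candidate per request by inverse-CDF sampling and accepts
    it if it lies inside the top-p nucleus; the loop exits as soon as
    every request has succeeded, so most batches use few rounds."""
    max_rounds, batch_size = uniform_samples.shape
    # Precompute nucleus membership (used only for the accept test here).
    sorted_p, idx = torch.sort(probs, dim=-1, descending=True)
    cum = torch.cumsum(sorted_p, dim=-1)
    keep_sorted = cum - sorted_p < top_p
    nucleus = torch.zeros_like(keep_sorted).scatter(-1, idx, keep_sorted)

    samples = torch.zeros(batch_size, dtype=torch.long)
    success = torch.zeros(batch_size, dtype=torch.bool)
    cdf = torch.cumsum(probs, dim=-1)
    for r in range(max_rounds):
        # Inverse-CDF draw: first index whose cumulative mass covers u.
        draw = torch.searchsorted(cdf, uniform_samples[r].unsqueeze(-1))
        draw = draw.squeeze(-1).clamp(max=probs.shape[-1] - 1)
        accept = nucleus[torch.arange(batch_size), draw] & ~success
        samples = torch.where(accept, draw, samples)
        success = success | accept
        if success.all():  # early stopping: remaining rounds are skipped
            break
    return samples, success

torch.manual_seed(0)
probs = torch.softmax(torch.randn(4, 8), dim=-1)
uniform_samples = torch.rand(32, 4)
samples, success = toy_rejection_top_p(probs, uniform_samples, 0.9)
```

A generous max_rounds only bounds the worst case; as in the fused kernel, the loop typically terminates after a handful of rounds.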