flashinfer.sampling.top_k_top_p_sampling_from_logits

flashinfer.sampling.top_k_top_p_sampling_from_logits(logits: Tensor, top_k: Tensor | int, top_p: Tensor | float, indices: Tensor | None = None, filter_apply_order: str = 'top_k_first', deterministic: bool = True, generator: Generator | None = None, check_nan: bool = False, seed: int | Tensor | None = None, offset: int | Tensor | None = None) → Tensor

Fused GPU kernel for top-k and top-p sampling from pre-softmax logits.

This operator implements GPU-based rejection sampling without explicit sorting; see the blog post for more details.

All rounds of rejection sampling run inside a single CUDA kernel, which is more efficient than a naive implementation that launches a separate kernel per round.
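
For intuition, the default "top_k_first" order is equivalent to the following naive PyTorch reference (an illustrative sketch with a hypothetical name, not the fused kernel; it sorts explicitly and assumes scalar top_k and top_p):

import torch

def reference_top_k_top_p_sample(logits, top_k: int, top_p: float, generator=None):
    # Top-k filter: zero out entries below the k-th largest probability
    # (ties at the threshold may keep slightly more than k entries).
    probs = torch.softmax(logits, dim=-1)
    kth = probs.topk(top_k, dim=-1).values[..., -1:]
    probs = torch.where(probs >= kth, probs, torch.zeros_like(probs))
    probs = probs / probs.sum(dim=-1, keepdim=True)
    # Top-p (nucleus) filter: keep the smallest prefix of the sorted
    # distribution whose cumulative mass reaches top_p; the mode is always kept.
    sorted_probs, order = probs.sort(dim=-1, descending=True)
    keep_sorted = sorted_probs.cumsum(dim=-1) - sorted_probs < top_p
    keep = torch.zeros_like(probs).scatter(-1, order, keep_sorted.to(probs.dtype)).bool()
    probs = torch.where(keep, probs, torch.zeros_like(probs))
    probs = probs / probs.sum(dim=-1, keepdim=True)
    # Draw one category per row from the filtered, renormalized distribution.
    return torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)

The fused kernel samples from the same distribution via rejection sampling instead of sorting; with filter_apply_order="joint", both filters are applied together in each rejection round rather than top-k first.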

Parameters:
  • logits (torch.Tensor) – Pre-softmax logits for sampling. When indices is not provided, shape should be (batch_size, num_classes) and the i-th output will be sampled from the i-th row of logits. When indices is provided, shape should be (unique_batch_size, num_classes) where unique_batch_size is the number of unique probability distributions.

  • top_k (Union[torch.Tensor, int]) – Either a scalar or a tensor of shape (batch_size,), representing the threshold for top-k sampling. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold.

  • top_p (Union[torch.Tensor, float]) – Either a scalar or a tensor of shape (batch_size,), representing the threshold for top-p sampling. If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold.

  • indices (Optional[torch.Tensor]) – Optional indices tensor of shape (batch_size,), dtype torch.int32 or torch.int64, that maps each output to a row in logits. The output tensor will have the same dtype as indices. For example, if indices[i] = j, then the i-th output will be sampled from logits[j]. This allows reusing the same probability distribution for multiple outputs (see the sketch after this parameter list). If indices is not provided, the i-th output will be sampled from the i-th row of logits and the output dtype defaults to torch.int32.

  • filter_apply_order (str) – The order of applying top-k and top-p sampling, should be either "top_k_first" or "joint". If "top_k_first", we first apply top-k filter, then apply top-p sampling on the top-k results. If "joint", we apply top-k and top-p filter simultaneously in each round. Default is "top_k_first".

  • deterministic (bool) – Whether to use deterministic kernel implementation, default is True.

  • generator (Optional[torch.Generator]) – A random number generator for the operation.

  • check_nan (bool) – Whether to check for NaN values in the input logits, default is False.

  • seed (Optional[Union[int, torch.Tensor]]) –

    Random seed value for the sampling operation. Can be either an integer or a torch.Tensor. When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. Using torch.Tensor is required for CUDA graph compatibility.

    Warning: If you provide seed and offset explicitly, you are responsible for updating their values between calls to ensure different random samples. Common approaches include:

      - Incrementing offset by the number of random values consumed
      - Updating seed based on the number of calls to the operation

  • offset (Optional[Union[int, torch.Tensor]]) –

    Random offset value for the sampling operation. Can be either an integer or a torch.Tensor. When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. Using torch.Tensor is required for CUDA graph compatibility.

    Warning: If you provide seed and offset explicitly, you are responsible for updating their values between calls to ensure different random samples. The offset should be incremented based on the number of random values consumed by the operation (see the sketch at the end of the Examples section).
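
The sketch below illustrates per-request thresholds together with indices (shapes, values, and the int32 dtype for top_k are illustrative assumptions): three outputs are drawn from two unique rows of logits, and the output dtype follows indices.

import torch
import flashinfer

unique_logits = torch.randn(2, 128, device="cuda")  # 2 unique distributions
# Outputs 0 and 1 reuse row 0; output 2 uses row 1.
indices = torch.tensor([0, 0, 1], device="cuda", dtype=torch.int32)
top_k = torch.tensor([20, 40, 10], device="cuda", dtype=torch.int32)  # per-request k
top_p = torch.tensor([0.9, 0.8, 0.95], device="cuda")                 # per-request p
samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    unique_logits, top_k, top_p, indices=indices
)
# samples: shape (3,), dtype torch.int32 (matching indices)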

Returns:

samples – Sampled categories, shape (batch_size,).

Return type:

torch.Tensor

Examples

>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_p = 0.5
>>> top_k = 3
>>> logits = torch.randn(batch_size, vocab_size).to(0)
>>> logits
tensor([[ 1.9269,  1.4873,  0.9007, -2.1055, -0.7581],
        [ 1.0783,  0.8008,  1.6806,  0.3559, -0.6866],
        [-0.4934,  0.2415, -0.2316,  0.0418, -0.2516],
        [ 0.8599, -0.3097, -0.3957,  0.8034, -0.6216]], device='cuda:0')
>>> samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(logits, top_k, top_p)
>>> samples
tensor([0, 2, 1, 3], device='cuda:0', dtype=torch.int32)
>>> probs = torch.softmax(logits, dim=-1)
>>> probs
tensor([[0.4788, 0.3085, 0.1716, 0.0085, 0.0327],
        [0.2358, 0.1787, 0.4307, 0.1145, 0.0404],
        [0.1358, 0.2831, 0.1764, 0.2318, 0.1729],
        [0.3613, 0.1122, 0.1029, 0.3415, 0.0821]], device='cuda:0')
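
When driving the RNG state explicitly (for example, under CUDA graph capture), seed and offset are passed as tensors and advanced in place between calls, as the warnings above require. In this sketch the per-call offset increment is illustrative only; it must cover however many random values the kernel actually consumes.

import torch
import flashinfer

seed = torch.tensor([42], dtype=torch.int64, device="cuda")
offset = torch.tensor([0], dtype=torch.int64, device="cuda")

logits = torch.randn(4, 128, device="cuda")
samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    logits, top_k=20, top_p=0.9, seed=seed, offset=offset
)
# Advance the offset in place so the next call (or CUDA-graph replay)
# draws fresh randomness; the increment is illustrative, not exact.
offset.add_(1024)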

Note

This function expects float32 logits. The output dtype is int32 by default, or matches the dtype of indices when indices is provided.