flashinfer.sampling.sampling_from_probs

flashinfer.sampling.sampling_from_probs(probs: Tensor, indices: Tensor | None = None, deterministic: bool = True, generator: Generator | None = None, check_nan: bool = False, seed: int | Tensor | None = None, offset: int | Tensor | None = None) → Tensor

Fused GPU kernel for categorical sampling from probabilities.

Parameters:
  • probs (torch.Tensor) – Probabilities for sampling. When indices is not provided, shape should be (batch_size, num_classes) and the i-th output will be sampled from the i-th row of probabilities. When indices is provided, shape should be (unique_batch_size, num_classes) where unique_batch_size is the number of unique probability distributions.

  • indices (Optional[torch.Tensor]) – Optional indices tensor of shape (batch_size,), with dtype torch.int32 or torch.int64, that maps each output to a row in probs; the output tensor will have the same dtype as indices. For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. This allows reusing the same probability distribution for multiple outputs (see the second example below). If indices is not provided, the i-th output is sampled from the i-th row of probs and the output dtype defaults to torch.int32.

  • deterministic (bool) – Whether to use the deterministic kernel implementation. Default is True.

  • generator (Optional[torch.Generator]) – A random number generator for the operation (see the last example below).

  • check_nan (bool) – Whether to check for NaN values in probs. Default is False.

  • seed (Optional[Union[int, torch.Tensor]]) –

    Random seed value for the sampling operation. Can be either an integer or a torch.Tensor. When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. Using torch.Tensor is required for CUDA graph compatibility.

    Warning: If you provide seed and offset explicitly, you are responsible for updating their values between calls to ensure different random samples. Common approaches include:

      - incrementing offset by the number of random values consumed, or
      - updating seed based on the number of calls to the operation.

    A minimal sketch of this pattern appears after the parameter list below.

  • offset (Optional[Union[int, torch.Tensor]]) –

    Random offset value for the sampling operation. Can be either an integer or a torch.Tensor. When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. Using torch.Tensor is required for CUDA graph compatibility.

    Warning: If you provide seed and offset explicitly, you are responsible for updating their values between calls to ensure different random samples. The offset should be incremented based on the number of random values consumed by the operation.
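
The following is a minimal sketch of the seed/offset pattern referenced in the warnings above. It assumes device-resident int64 tensors and uses an illustrative increment of one random value per output row; the exact number of values consumed by the kernel is not specified here, so treat the increment as a placeholder.

>>> import torch
>>> import flashinfer
>>> probs = torch.softmax(torch.randn(4, 128, device="cuda:0"), dim=-1)
>>> # Length-1 int64 tensors: using tensors (rather than Python ints) keeps the
>>> # buffers reusable in place, e.g. for CUDA graph compatibility.
>>> seed = torch.tensor([7], dtype=torch.int64, device="cuda:0")
>>> offset = torch.tensor([0], dtype=torch.int64, device="cuda:0")
>>> samples = flashinfer.sampling.sampling_from_probs(probs, seed=seed, offset=offset)
>>> # Advance the offset before the next call; this increment is illustrative only.
>>> offset += probs.shape[0]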

Returns:

samples – Sampled categories, shape (batch_size,).

Return type:

torch.Tensor

Examples

>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> norm_prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
        [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
        [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
        [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> samples = flashinfer.sampling.sampling_from_probs(norm_prob)
>>> samples
tensor([1, 2, 1, 4], device='cuda:0', dtype=torch.int32)
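
To reuse a small set of distributions for a larger batch, pass indices as described above. A minimal sketch continuing the session (sampled values are omitted because they depend on the random state):

>>> unique_probs = torch.rand(2, vocab_size).to(0)
>>> unique_probs = unique_probs / unique_probs.sum(dim=-1, keepdim=True)
>>> indices = torch.tensor([0, 0, 1, 1], dtype=torch.int32, device="cuda:0")
>>> samples = flashinfer.sampling.sampling_from_probs(unique_probs, indices=indices)
>>> samples.shape
torch.Size([4])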

Note

This function expects float32 inputs; the output dtype is int32 unless indices is provided, in which case the output has the same dtype as indices.
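
If you prefer to drive the randomness through a torch.Generator rather than explicit seed and offset values, a minimal sketch reusing norm_prob from the example above (whether the generator must live on the same CUDA device as probs may depend on your flashinfer version):

>>> gen = torch.Generator(device="cuda:0")
>>> _ = gen.manual_seed(123)
>>> samples = flashinfer.sampling.sampling_from_probs(norm_prob, generator=gen)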