flashinfer.sampling.sampling_from_probs
- flashinfer.sampling.sampling_from_probs(probs: Tensor, indices: Tensor | None = None, deterministic: bool = True, generator: Generator | None = None, check_nan: bool = False, seed: int | Tensor | None = None, offset: int | Tensor | None = None) → Tensor
Fused GPU kernel for category sampling from probabilities.
- Parameters:
  - probs (torch.Tensor) – Probabilities for sampling. When indices is not provided, shape should be (batch_size, num_classes) and the i-th output will be sampled from the i-th row of probabilities. When indices is provided, shape should be (unique_batch_size, num_classes), where unique_batch_size is the number of unique probability distributions.
  - indices (Optional[torch.Tensor]) – Optional indices tensor of shape (batch_size,), dtype torch.int32 or torch.int64, that maps each output to a row in probs. The output tensor will have the same dtype as indices. For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of probs and the output dtype defaults to torch.int32. A minimal usage sketch is shown after the Examples below.
  - deterministic (bool) – Whether to use the deterministic kernel implementation, default is True.
  - generator (Optional[torch.Generator]) – A random number generator for the operation.
  - check_nan (bool) – Whether to check for nan in probs, default is False.
  - seed (Optional[Union[int, torch.Tensor]]) – Random seed value for the sampling operation. Can be either an integer or a torch.Tensor. When provided as a torch.Tensor, it must be of int64 or uint64 dtype, 1D, and of length 1 or batch_size. Using torch.Tensor is required for CUDA graph compatibility.
    Warning: If you provide seed and offset explicitly, you are responsible for updating their values between calls to ensure different random samples. Common approaches include:
    - Incrementing offset by the number of random values consumed
    - Updating seed based on the number of calls to the operation
  - offset (Optional[Union[int, torch.Tensor]]) – Random offset value for the sampling operation. Can be either an integer or a torch.Tensor. When provided as a torch.Tensor, it must be of int64 or uint64 dtype, 1D, and of length 1 or batch_size. Using torch.Tensor is required for CUDA graph compatibility.
    Warning: If you provide seed and offset explicitly, you are responsible for updating their values between calls to ensure different random samples. The offset should be incremented based on the number of random values consumed by the operation. A sketch of explicit seed/offset handling appears at the end of this page.
- Returns:
samples – Sampled categories, shape (batch_size,).
- Return type:
torch.Tensor
Examples
```python
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> norm_prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
        [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
        [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
        [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> samples = flashinfer.sampling.sampling_from_probs(norm_prob)
>>> samples
tensor([1, 2, 1, 4], device='cuda:0', dtype=torch.int32)
```
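The indices path described in the parameter list can be exercised in the same way. The following is a minimal sketch only: the names shared_prob and row_indices are illustrative, and the sampled values are not shown because they depend on the RNG state. Two probability rows are shared across four outputs, so the result has shape (batch_size,) and the dtype of indices.

```python
import torch
import flashinfer

# Two unique probability rows shared by four outputs (illustrative names).
shared_prob = torch.rand(2, 5, device="cuda:0")
shared_prob = shared_prob / shared_prob.sum(dim=-1, keepdim=True)

# indices[i] = j means the i-th output is sampled from shared_prob[j].
row_indices = torch.tensor([0, 0, 1, 1], dtype=torch.int32, device="cuda:0")

samples = flashinfer.sampling.sampling_from_probs(shared_prob, indices=row_indices)
# samples has shape (4,) and dtype torch.int32 (the dtype of row_indices).
```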
Note
This function expects float32 inputs; the output is int32, or the dtype of indices when indices is provided.
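As a rough sketch of the explicit seed/offset path described in the parameter warnings above (not a full CUDA graph capture example), seed and offset can be passed as length-1 int64 tensors so their values can be updated in place between calls. The per-call increment of offset shown here is an assumption for illustration; the exact number of random values consumed by the kernel is not specified on this page.

```python
import torch
import flashinfer

probs = torch.rand(4, 5, device="cuda:0")
probs = probs / probs.sum(dim=-1, keepdim=True)

# Length-1 int64 tensors: the form required for CUDA graph compatibility,
# and mutable in place between calls.
seed = torch.tensor([42], dtype=torch.int64, device="cuda:0")
offset = torch.tensor([0], dtype=torch.int64, device="cuda:0")

for _ in range(3):
    samples = flashinfer.sampling.sampling_from_probs(probs, seed=seed, offset=offset)
    # The caller must advance offset between calls to get different samples.
    # Incrementing by batch_size is an assumed step size for illustration; the
    # true increment depends on how many random values the kernel consumes.
    offset += probs.shape[0]
```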