flashinfer.sampling.sampling_from_probs
- flashinfer.sampling.sampling_from_probs(probs: torch.Tensor, indices: torch.Tensor | None = None, deterministic: bool = True, generator: torch.Generator | None = None, check_nan: bool = False) → torch.Tensor
Fused GPU kernel for categorical sampling from probabilities.
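For reference, categorical sampling can be expressed as inverse-CDF sampling. For each row, draw u uniformly from [0, total mass) and pick the first class whose cumulative probability exceeds u. The sketch below illustrates that textbook technique in plain PyTorch, assuming non-negative rows. The helper name `inverse_cdf_sample` is hypothetical, and the sketch makes no claim about how FlashInfer's fused kernel is implemented.

```python
from typing import Optional

import torch


def inverse_cdf_sample(probs: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Hypothetical illustration: one category per row via inverse-CDF sampling (not FlashInfer's kernel)."""
    # Cumulative distribution along the class dimension: shape (batch_size, num_classes).
    cdf = torch.cumsum(probs, dim=-1)
    # One uniform draw per row, scaled by the row total so unnormalized rows also work.
    u = torch.rand(probs.shape[0], 1, generator=generator, device=probs.device, dtype=probs.dtype)
    u = u * cdf[:, -1:]
    # right=True returns the first index whose CDF value is strictly greater than u,
    # so zero-probability classes are never selected.
    idx = torch.searchsorted(cdf, u, right=True).squeeze(-1)
    # Guard against floating-point round-off pushing u past the last CDF entry.
    idx = idx.clamp(max=probs.shape[-1] - 1)
    # Match the int32 output dtype used by sampling_from_probs.
    return idx.to(torch.int32)
```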
- Parameters:
probs (torch.Tensor) – Probabilities for sampling. When indices is not provided, the shape should be (batch_size, num_classes), and the i-th output is sampled from the i-th row of probs. When indices is provided, the shape should be (unique_batch_size, num_classes), where unique_batch_size is the number of unique probability distributions.
indices (Optional[torch.Tensor]) – Optional indices tensor of shape (batch_size,) that maps each output to a row in probs. For example, if indices[i] = j, then the i-th output is sampled from probs[j]. This allows the same probability distribution to be reused for multiple outputs. If indices is not provided, the i-th output is sampled from the i-th row of probs.
deterministic (bool) – Whether to use the deterministic kernel implementation. Default is True.
generator (Optional[torch.Generator]) – A random number generator for the operation.
check_nan (bool) – Whether to check for NaN values in probs. Default is False.
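The semantics of probs, indices and generator described above can be mirrored in plain PyTorch, for example to write CPU-side tests. This is a minimal sketch, assuming each row of probs is non-negative with a positive sum. `reference_sampling_from_probs` is a hypothetical name. It ignores deterministic and check_nan, and because it uses `torch.multinomial`, it will not reproduce the fused kernel's random draws even with the same generator seed.

```python
from typing import Optional

import torch


def reference_sampling_from_probs(
    probs: torch.Tensor,
    indices: Optional[torch.Tensor] = None,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    # Hypothetical pure-PyTorch reference for the documented semantics;
    # it is not the fused FlashInfer kernel.
    if indices is not None:
        # Output i is drawn from row indices[i]: gather rows into shape (batch_size, num_classes).
        probs = probs[indices.long()]
    # torch.multinomial draws one category per row; rows need not sum exactly to 1.
    samples = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)
    # Match the documented int32 output dtype.
    return samples.to(torch.int32)
```

For example, with probs = [[1, 0, 0], [0, 0, 1]] and indices = [0, 1, 1, 0, 1], the result is [0, 2, 2, 0, 2], because each one-hot row can only yield its single nonzero class.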
- Returns:
samples – Sampled categories, shape (batch_size,).
- Return type:
torch.Tensor
Examples
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> norm_prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
        [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
        [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
        [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> samples = flashinfer.sampling.sampling_from_probs(norm_prob)
>>> samples
tensor([1, 2, 1, 4], device='cuda:0', dtype=torch.int32)
Note
This function expects float32 inputs, and the output is int32.
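Because the function expects float32 probabilities, logits held in float16 or bfloat16 are usually converted before sampling. This is a minimal sketch, assuming the probabilities come from a softmax over logits. `to_float32_probs` is a hypothetical helper, and its NaN check is an eager Python-side alternative to the check_nan flag, not a description of what that flag does internally.

```python
import torch


def to_float32_probs(logits: torch.Tensor, check_nan: bool = True) -> torch.Tensor:
    # Hypothetical helper: upcast to float32 before the softmax so the
    # probabilities match the float32 input the sampler expects.
    probs = torch.softmax(logits.float(), dim=-1)
    if check_nan and torch.isnan(probs).any():
        # Fail early on the host instead of passing NaN rows to the sampler.
        raise ValueError("probs contains NaN")
    return probs
```

If downstream code needs int64 indices, the int32 samples can be converted with `samples.long()`.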