flashinfer.sampling.top_k_sampling_from_probs
- flashinfer.sampling.top_k_sampling_from_probs(probs: torch.Tensor, uniform_samples: torch.Tensor, top_k: int)
Fused GPU kernel for top-k sampling from probabilities. This operator implements GPU-based rejection sampling without explicit sorting.
All rounds of rejection sampling run inside a single CUDA kernel, which is more efficient than a naive implementation that launches a series of kernels.
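To make sorting-free rejection sampling concrete, here is a minimal pure-PyTorch reference sketch of one way such a scheme can work. It assumes a pivot-based rule: a drawn class is accepted when fewer than `top_k` classes have strictly larger probability; otherwise its probability becomes a pivot, and that class and every class no more likely than it are excluded from the next round. The function name `top_k_rejection_sample_reference` is hypothetical, the per-row Python loop is for readability only, and this is not FlashInfer's CUDA kernel, whose details may differ.

```python
import torch


def top_k_rejection_sample_reference(probs, uniform_samples, top_k):
    """Hypothetical CPU reference: top-k sampling by rejection, without sorting.

    probs: (batch_size, num_classes) probabilities.
    uniform_samples: (max_rounds, batch_size) values in [0, 1).
    Returns (samples, success), both of shape (batch_size,).
    """
    batch_size, num_classes = probs.shape
    max_rounds = uniform_samples.shape[0]
    samples = torch.zeros(batch_size, dtype=torch.int32)
    success = torch.zeros(batch_size, dtype=torch.bool)

    for b in range(batch_size):
        p = probs[b]
        pivot = 0.0  # assumed rule: only classes with p > pivot remain candidates
        for r in range(max_rounds):
            candidate = torch.where(p > pivot, p, torch.zeros_like(p))
            cdf = torch.cumsum(candidate, dim=0)
            # Inverse-CDF draw over the candidate set using this round's uniform sample:
            # the chosen index is the first position where cdf exceeds u.
            u = uniform_samples[r, b] * cdf[-1]
            idx = int((cdf <= u).sum())
            if idx >= num_classes:  # guard against floating-point round-off at the top end
                idx = int(torch.nonzero(candidate)[-1])
            q = p[idx]
            # Accept if fewer than top_k classes are strictly more likely than the draw.
            if int((p > q).sum()) < top_k:
                samples[b] = idx
                success[b] = True
                break
            # Reject: raise the pivot so this class and all less likely ones drop out.
            pivot = float(q)
        # If no round accepted, samples[b] stays 0 and success[b] stays False.
    return samples, success
```

The pivot only ever rises to the probability of a rejected draw, which by construction has at least `top_k` classes strictly above it, so the candidate set always still contains the true top-k classes. With `top_k=1`, every accepted draw is the row's most likely class, consistent with the Examples section.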
- Parameters:
  - probs (torch.Tensor) – Probabilities, shape (batch_size, num_classes).
  - uniform_samples (torch.Tensor) – The uniform samples used as the needle for sampling, shape (max_top_k_rounds, batch_size), where the first dimension is the maximum number of rounds of rejection sampling. Values are expected to be uniformly distributed in [0, 1).
  - top_k (int) – The k in “top-k”.
- Returns:
  (samples, success) –
  - samples: torch.Tensor
    Sampled categories, shape (batch_size,).
  - success: torch.Tensor
    Whether sampling succeeded within max_top_k_rounds rounds, shape (batch_size,).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
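Rows whose `success` entry is False did not accept a sample within `max_top_k_rounds` rounds. This page does not prescribe what to do with them; one possible approach, sketched below as an assumption, is to resample only those rows from the top-k renormalized distribution using an explicit `torch.topk` selection and `torch.multinomial`. The helper name `fill_failed_rows` is hypothetical.

```python
import torch


def fill_failed_rows(probs, samples, success, top_k):
    """Hypothetical fallback: resample rows where success is False.

    Uses explicit torch.topk selection, which is affordable here because it
    only touches the (usually few) failed rows. Returns a new tensor.
    """
    samples = samples.clone()  # do not modify the caller's tensor
    failed = ~success
    if bool(failed.any()):
        topk_vals, topk_idx = torch.topk(probs[failed], top_k, dim=-1)
        # torch.multinomial treats each row as unnormalized non-negative weights,
        # so this draws from the top-k probabilities renormalized per row.
        choice = torch.multinomial(topk_vals, num_samples=1)
        # Map the position within the top-k back to the original class index.
        samples[failed] = topk_idx.gather(-1, choice).squeeze(-1).to(samples.dtype)
    return samples
```

Rows that already succeeded keep the kernel's sample unchanged, and the output keeps the int32 dtype of `samples`.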
Examples
```python
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> max_top_k_rounds = 3
>>> top_k = 1
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
>>> norm_prob
tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
        [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
        [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
        [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> uniform_samples = torch.rand(max_top_k_rounds, batch_size).to(0)
>>> samples, success = flashinfer.sampling.top_k_sampling_from_probs(norm_prob, uniform_samples, top_k)
>>> samples
tensor([3, 3, 0, 1], device='cuda:0', dtype=torch.int32)
>>> success
tensor([True, True, True, True], device='cuda:0')
```
Notes
This function expects float32 inputs, and the returned samples are int32. We encourage users to set max_top_k_rounds (the first dimension of uniform_samples) to a reasonable value, e.g., 32. In practice, the implementation usually needs far fewer rounds of rejection sampling thanks to early stopping.
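The input requirements above (float32 probabilities, uniform samples in [0, 1) of shape (max_top_k_rounds, batch_size), around 32 rounds) can be bundled into a small preparation step. The sketch below assumes the caller already has a probability tensor; the helper name `prepare_top_k_inputs` is hypothetical, and any optional `generator` must live on the same device as `probs`.

```python
import torch


def prepare_top_k_inputs(probs, max_top_k_rounds=32, generator=None):
    """Hypothetical helper: cast probs to float32 and draw uniform_samples.

    Returns (probs_f32, uniform_samples), with uniform_samples of shape
    (max_top_k_rounds, batch_size) on the same device as probs.
    """
    probs_f32 = probs.to(torch.float32)
    batch_size = probs_f32.shape[0]
    # torch.rand draws from [0, 1), matching the documented range.
    uniform_samples = torch.rand(
        max_top_k_rounds,
        batch_size,
        generator=generator,
        dtype=torch.float32,
        device=probs_f32.device,
    )
    return probs_f32, uniform_samples
```

The two returned tensors can then be passed directly as `probs` and `uniform_samples` to `flashinfer.sampling.top_k_sampling_from_probs` along with `top_k`.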