flashinfer.sampling.sampling_from_logits
- flashinfer.sampling.sampling_from_logits(logits: Tensor, indices: Tensor | None = None, deterministic: bool = True, generator: Generator | None = None, check_nan: bool = False, seed: int | Tensor | None = None, offset: int | Tensor | None = None) → Tensor
Fused GPU kernel for categorical sampling from logits. It is equivalent to applying softmax to logits and then sampling from the resulting probabilities.
- Parameters:
logits (torch.Tensor) – Logits for sampling. When indices is not provided, the shape should be (batch_size, num_classes) and the i-th output is sampled from the i-th row of logits. When indices is provided, the shape should be (unique_batch_size, num_classes), where unique_batch_size is the number of unique probability distributions.
indices (Optional[torch.Tensor]) – Optional indices tensor of shape (batch_size,) and dtype torch.int32 or torch.int64 that maps each output to a row in logits. The output tensor has the same dtype as indices. For example, if indices[i] = j, then the i-th output is sampled from logits[j]. This allows the same probability distribution to be reused for multiple outputs. If indices is not provided, the i-th output is sampled from the i-th row of logits and the output dtype defaults to torch.int32.
deterministic (bool) – Because this sampling does not use CUB's BlockScan, it is always deterministic. The argument is kept for compatibility with the other sampling functions.
generator (Optional[torch.Generator]) – A random number generator for the operation.
check_nan (bool) – Whether to check logits for NaN values. Default is False.
seed (Optional[Union[int, torch.Tensor]]) –
Random seed value for the sampling operation. Can be either an integer or a torch.Tensor. When provided as a torch.Tensor, it must have dtype int64 or uint64, be 1D, and have length 1 or batch_size. Passing a torch.Tensor is required for CUDA graph compatibility.
Warning: If you provide seed and offset explicitly, you are responsible for updating their values between calls to ensure different random samples. Common approaches include:
- Incrementing offset by the number of random values consumed
- Updating seed based on the number of calls to the operation
offset (Optional[Union[int, torch.Tensor]]) –
Random offset value for the sampling operation. Can be either an integer or a torch.Tensor. When provided as a torch.Tensor, it must have dtype int64 or uint64, be 1D, and have length 1 or batch_size. Passing a torch.Tensor is required for CUDA graph compatibility.
Warning: If you provide seed and offset explicitly, you are responsible for updating their values between calls to ensure different random samples. The offset should be incremented based on the number of random values consumed by the operation.
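The warnings above come down to one rule: identical (seed, offset) inputs reproduce identical samples, so the caller has to advance them between calls. The plain-Python sketch below illustrates that bookkeeping under stated assumptions. toy_uniform is a hypothetical SHA-256-based stand-in for the kernel's (seed, offset)-keyed generator that consumes one random value per row, and SeededSampler is a hypothetical wrapper, not part of FlashInfer. How far the real kernel's offset should advance per call is not specified here and should be checked against the implementation.

```python
import hashlib
import math
import struct
from typing import List, Sequence


def toy_uniform(seed: int, offset: int, row: int) -> float:
    # Hypothetical stand-in for the kernel's RNG: hashes (seed, offset, row)
    # into a float in [0, 1). The same inputs always give the same value.
    digest = hashlib.sha256(struct.pack("<QQQ", seed, offset, row)).digest()
    (bits,) = struct.unpack("<Q", digest[:8])
    return (bits >> 11) / float(1 << 53)  # keep 53 bits so the result is < 1.0


def softmax(row: Sequence[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def draw(logits: Sequence[Sequence[float]], seed: int, offset: int) -> List[int]:
    # Inverse-CDF sampling: row i uses the uniform keyed by (seed, offset, i).
    samples = []
    for i, row in enumerate(logits):
        u = toy_uniform(seed, offset, i)
        cumulative = 0.0
        choice = len(row) - 1  # fallback if round-off leaves cumulative <= u
        for k, p in enumerate(softmax(row)):
            cumulative += p
            if u < cumulative:
                choice = k
                break
        samples.append(choice)
    return samples


class SeededSampler:
    # Hypothetical wrapper that owns seed/offset and advances offset per call.
    def __init__(self, seed: int, offset: int = 0) -> None:
        self.seed = seed
        self.offset = offset

    def __call__(self, logits: Sequence[Sequence[float]]) -> List[int]:
        samples = draw(logits, self.seed, self.offset)
        # Assumption: this toy uses one value per row at a given offset, so
        # bumping offset by 1 yields fresh values on the next call.
        self.offset += 1
        return samples
```

Here draw(logits, seed=42, offset=0) returns the same list every time, whereas consecutive calls to one SeededSampler(seed=42) instance use offsets 0, 1, 2, and so on. With tensor-valued seed and offset under CUDA graph capture, the analogous update would be done in place on the tensor (for example offset.add_(n)) rather than by rebinding a Python integer, since a replayed graph reads the memory it captured.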
- Returns:
samples – Sampled categories, shape (batch_size,). It is equivalent to sampling from logits after applying softmax.
- Return type:
torch.Tensor
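The semantics described above (softmax followed by a categorical draw, with indices selecting which row of logits each output uses) can be written as a slow pure-Python reference. This is an illustrative sketch, not FlashInfer's kernel: reference_sampling_from_logits is a hypothetical name, it draws with Python's random module instead of the kernel's generator (so it will not reproduce the kernel's exact samples), and it returns Python ints rather than a tensor whose dtype follows indices.

```python
import math
import random
from typing import List, Optional, Sequence


def softmax(row: Sequence[float]) -> List[float]:
    # Subtracting the row maximum avoids overflow and does not change the result.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def reference_sampling_from_logits(
    logits: Sequence[Sequence[float]],
    indices: Optional[Sequence[int]] = None,
    rng: Optional[random.Random] = None,
) -> List[int]:
    # Hypothetical CPU reference: output i is drawn from softmax(logits[indices[i]]).
    if rng is None:
        rng = random.Random()
    if indices is None:
        # Without indices, output i is drawn from row i.
        indices = range(len(logits))
    samples = []
    for j in indices:
        probs = softmax(logits[j])
        # random.choices draws one category index with the given weights.
        samples.append(rng.choices(range(len(probs)), weights=probs, k=1)[0])
    return samples
```

Reusing a row through indices costs one softmax per output in this sketch; the point of indices in the real API is that many outputs can share one stored distribution without duplicating it in logits.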
Examples
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> logits = torch.rand(batch_size, vocab_size).to(0)
>>> logits
tensor([[0.8823, 0.9150, 0.3829, 0.9593, 0.3904],
        [0.6009, 0.2566, 0.7936, 0.9408, 0.1332],
        [0.9346, 0.5936, 0.8694, 0.5677, 0.7411],
        [0.4294, 0.8854, 0.5739, 0.2666, 0.6274]], device='cuda:0')
>>> samples = flashinfer.sampling.sampling_from_logits(logits)
>>> samples
tensor([0, 1, 1, 1], device='cuda:0', dtype=torch.int32)
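Because each call returns random draws, one way to sanity-check the softmax equivalence is to sample many times from a single row and compare empirical frequencies with the softmax probabilities. Following the indices semantics described above, such samples could be collected by passing a single row of logits together with an all-zero indices tensor of length n and converting the result with .tolist(); that GPU step is not shown. The helpers below, max_frequency_error and simulate_reference_samples, are hypothetical stdlib-only sketches, and simulate_reference_samples exists only to exercise the checker without a GPU.

```python
import math
import random
from collections import Counter
from typing import List, Sequence


def softmax(row: Sequence[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def max_frequency_error(samples: Sequence[int], logits_row: Sequence[float]) -> float:
    # Largest absolute gap between each category's empirical frequency
    # and its softmax probability.
    probs = softmax(logits_row)
    counts = Counter(samples)
    n = len(samples)
    return max(abs(counts[k] / n - p) for k, p in enumerate(probs))


def simulate_reference_samples(logits_row: Sequence[float], n: int, seed: int = 0) -> List[int]:
    # CPU stand-in for the kernel: n independent draws from softmax(logits_row).
    rng = random.Random(seed)
    probs = softmax(logits_row)
    return rng.choices(range(len(probs)), weights=probs, k=n)
```

With n = 100,000 draws, the standard error of each frequency is at most about 0.0016, so a tolerance around 0.01 would flag a genuine mismatch while rarely failing by chance.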