flashinfer.sampling

Kernels for LLM sampling.

See also

For efficient Top-K selection (without sampling), see flashinfer.topk which provides top_k(), top_k_page_table_transform(), and top_k_ragged_transform().

sampling_from_probs(probs[, indices, ...])

Fused GPU kernel for categorical sampling from probabilities.
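Categorical sampling draws index i with probability proportional to probs[i]. The plain-Python sketch below shows one common way to do this: inverse-CDF sampling with a single uniform number per row. It is an illustration written for this page, not FlashInfer's kernel, and the helper names `sample_row` and `sample_batch` are hypothetical.

```python
def sample_row(probs, u):
    """Inverse-CDF sampling: return the first index whose running sum exceeds u * sum(probs).

    probs: non-negative weights for one row; u: a uniform draw in [0, 1).
    """
    target = u * sum(probs)
    cum = 0.0
    for i, p in enumerate(probs):
        cum += p
        if cum > target:
            return i
    # Round-off guard: fall back to the last index that carries probability mass.
    return max(i for i, p in enumerate(probs) if p > 0)


def sample_batch(prob_rows, uniforms):
    """Draw one index per row, using one uniform number per row."""
    return [sample_row(row, u) for row, u in zip(prob_rows, uniforms)]


print(sample_batch([[0.1, 0.2, 0.7]] * 3, [0.05, 0.25, 0.95]))
```

Passing the uniform numbers in explicitly keeps the sketch deterministic. In practice they would come from a random number generator.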

sampling_from_logits(logits[, indices, ...])

Fused GPU kernel for categorical sampling from logits. It is equivalent to sampling from the probabilities obtained by applying softmax to the logits.

Parameters:

logits (torch.Tensor): Logits for sampling. When indices is not provided, the shape should be (batch_size, num_classes) and the i-th output is sampled from the i-th row of logits. When indices is provided, the shape should be (unique_batch_size, num_classes), where unique_batch_size is the number of unique probability distributions.

indices (Optional[torch.Tensor]): Optional indices tensor of shape (batch_size,) with dtype torch.int32 or torch.int64 that maps each output to a row in logits. The output tensor has the same dtype as indices. For example, if indices[i] = j, the i-th output is sampled from logits[j]. This allows the same probability distribution to be reused for multiple outputs. If indices is not provided, the i-th output is sampled from the i-th row of logits and the output dtype defaults to torch.int32.

deterministic (bool): Since the sampling doesn't use cub's BlockScan, the sampling is deterministic. This argument is kept for compatibility with the other sampling functions.

generator (Optional[torch.Generator]): A random number generator for the operation.

check_nan (bool): Whether to check for NaN in logits. Default is False.

seed: Random seed value for the sampling operation. It can be either an integer or a torch.Tensor. When provided as a torch.Tensor, it must have dtype int64 or uint64, be 1D, and have length 1 or batch_size. A torch.Tensor seed is required for CUDA graph compatibility.
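The indices mapping above lets several outputs share one distribution. The plain-Python sketch below models that behavior: each distinct row of logits is converted to probabilities once, and output i is drawn from row indices[i]. It is a hypothetical reference (`sampling_from_logits_ref` is not a FlashInfer function). The uniform numbers are passed in explicitly instead of a generator or seed, to keep it deterministic.

```python
import math


def softmax(row):
    """Numerically safe softmax for one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]


def sample_row(probs, u):
    """Inverse-CDF sampling with one uniform number u in [0, 1)."""
    cum = 0.0
    for i, p in enumerate(probs):
        cum += p
        if cum > u:
            return i
    return len(probs) - 1  # guard against round-off when u is close to 1


def sampling_from_logits_ref(logits, uniforms, indices=None):
    """Hypothetical reference: output i is drawn from softmax(logits[indices[i]])."""
    if indices is None:
        # Without indices, output i comes from row i.
        indices = list(range(len(logits)))
    # Each distinct distribution is computed once and reused by every output mapped to it.
    probs = [softmax(row) for row in logits]
    return [sample_row(probs[j], u) for j, u in zip(indices, uniforms)]
```

For example, with two rows of logits and indices = [1, 1, 0, 0], the first two outputs both come from row 1 and the last two from row 0.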

softmax(logits[, temperature, enable_pdl])

Fused GPU kernel for online safe softmax with temperature scaling.
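"Online safe softmax" usually refers to computing the maximum and the normalizer in a single pass. When a larger value appears, the running sum is rescaled to the new maximum. A second pass then normalizes. The plain-Python sketch below illustrates that technique together with temperature scaling (logits divided by the temperature). It is not FlashInfer's kernel, and the name `online_softmax` is hypothetical.

```python
import math


def online_softmax(logits, temperature=1.0):
    """Pass 1 tracks the running max and normalizer together; pass 2 normalizes."""
    m = float("-inf")  # running max of the temperature-scaled logits
    d = 0.0            # running sum of exp(x - m)
    for x in logits:
        x = x / temperature
        m_new = max(m, x)
        # Rescale the old sum to the new max, then add the new term.
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return [math.exp(x / temperature - m) / d for x in logits]


print(online_softmax([1000.0, 1000.0]))  # large logits do not overflow
```

Subtracting the running max keeps every exponent at or below zero, which is what makes the softmax "safe" against overflow.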

top_p_sampling_from_probs(probs, top_p[, ...])

Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities. This operator implements GPU-based rejection sampling without explicit sorting.
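The summary above mentions rejection sampling without sorting. The plain-Python sketch below shows one way such a scheme can work. It is our own illustration, not FlashInfer's kernel, which works on batched tensors and may use a different pivot strategy.

The procedure draws a token from the probabilities above a pivot. The token is accepted if the tokens strictly more likely than it carry less than top_p of the mass. Otherwise, the pivot is raised to the rejected token's probability and a new token is drawn. Every nucleus token is more likely than any rejected token, so raising the pivot never removes a nucleus token, and accepted draws follow the renormalized nucleus distribution. The helpers `sample_above` and `top_p_rejection_sample` are hypothetical, and top_p is assumed to lie in (0, 1].

```python
import random


def sample_above(probs, pivot, rng):
    """Draw i with probability proportional to probs[i], restricted to probs[i] > pivot."""
    support = [(i, p) for i, p in enumerate(probs) if p > pivot]
    target = rng.random() * sum(p for _, p in support)
    cum = 0.0
    for i, p in support:
        cum += p
        if cum > target:
            return i
    return support[-1][0]  # round-off guard


def top_p_rejection_sample(probs, top_p, rng):
    """Sorting-free top-p sampling by repeated draws and pivot updates."""
    pivot = 0.0
    while True:
        j = sample_above(probs, pivot, rng)
        # j is in the nucleus iff the mass of strictly more likely tokens is below top_p.
        mass_above = sum(p for p in probs if p > probs[j])
        if mass_above < top_p:
            return j
        # Reject: all nucleus tokens are more likely than j, so raise the pivot past j.
        pivot = probs[j]
```

With probs = [0.5, 0.3, 0.15, 0.05] and top_p = 0.7, only tokens 0 and 1 can be returned, with probabilities 0.625 and 0.375.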

top_k_sampling_from_probs(probs, top_k[, ...])

Fused GPU kernel for top-k sampling from probabilities. This operator implements GPU-based rejection sampling without explicit sorting.
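The same rejection idea can be applied to top-k. Here a drawn token is accepted when fewer than top_k tokens are strictly more likely than it. Otherwise, the pivot is raised to its probability and a new token is drawn. The plain-Python sketch below is an illustration with hypothetical helper names, not FlashInfer's kernel. In this sketch, tokens tied at the k-th probability are all kept, so the accepted set can exceed k when ties occur.

```python
import random


def sample_above(probs, pivot, rng):
    """Draw i with probability proportional to probs[i], restricted to probs[i] > pivot."""
    support = [(i, p) for i, p in enumerate(probs) if p > pivot]
    target = rng.random() * sum(p for _, p in support)
    cum = 0.0
    for i, p in support:
        cum += p
        if cum > target:
            return i
    return support[-1][0]  # round-off guard


def top_k_rejection_sample(probs, top_k, rng):
    """Sorting-free top-k sampling by repeated draws and pivot updates."""
    pivot = 0.0
    while True:
        j = sample_above(probs, pivot, rng)
        rank = sum(1 for p in probs if p > probs[j])  # tokens strictly more likely than j
        if rank < top_k:
            return j
        pivot = probs[j]  # reject j and everything at or below its probability
```

With probs = [0.1, 0.4, 0.2, 0.3] and top_k = 2, only tokens 1 and 3 can be returned, in the ratio 4:3.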

min_p_sampling_from_probs(probs, min_p[, ...])

Fused GPU kernel for min-p sampling from probabilities.
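Min-p sampling is commonly defined as keeping only the tokens whose probability is at least min_p times the largest probability, then sampling from the renormalized remainder. The plain-Python sketch below assumes that definition. It is an illustration rather than FlashInfer's kernel, and `min_p_sample` is a hypothetical name.

```python
def min_p_sample(probs, min_p, u):
    """Keep tokens with p >= min_p * max(probs), then inverse-CDF sample with u in [0, 1)."""
    threshold = min_p * max(probs)  # the cutoff scales with the most likely token
    kept = [(i, p) for i, p in enumerate(probs) if p >= threshold]
    target = u * sum(p for _, p in kept)
    cum = 0.0
    for i, p in kept:
        cum += p
        if cum > target:
            return i
    return kept[-1][0]  # round-off guard
```

Because the threshold is relative to the top probability, a confident distribution prunes aggressively, while a flat one keeps many candidates.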

top_k_top_p_sampling_from_logits(logits, ...)

Fused GPU kernel for top-k and top-p sampling from pre-softmax logits.

top_k_top_p_sampling_from_probs(probs, ...)

Fused GPU kernel for top-k and top-p sampling from probabilities.
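The two entries above combine top-k and top-p filtering. The sketch below shows the distribution a combined filter can produce. It makes an explicit assumption that top-k is applied first and that top-p is then applied to the renormalized top-k distribution. The library may combine the filters differently. For the logits variant, softmax would be applied first.

For clarity, this reference sorts. The kernels avoid sorting, so `top_k_top_p_filter` is a hypothetical helper, not the library's method.

```python
def top_k_top_p_filter(probs, top_k, top_p):
    """Keep the top_k tokens, then the smallest prefix of them reaching top_p; renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = order[:top_k]
    k_total = sum(probs[i] for i in kept)
    nucleus, cum = [], 0.0
    for i in kept:
        nucleus.append(i)
        cum += probs[i] / k_total  # cumulative mass within the top-k distribution
        if cum >= top_p:
            break
    total = sum(probs[i] for i in nucleus)
    out = [0.0] * len(probs)
    for i in nucleus:
        out[i] = probs[i] / total
    return out
```

Sampling from the returned distribution gives the combined top-k/top-p behavior under the stated filter-order assumption.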

top_p_renorm_probs(probs, top_p[, ...])

Fused GPU kernel for renormalizing probabilities by top-p thresholding.
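Top-p renormalization keeps the nucleus and rescales it to sum to 1, without drawing a sample. The plain-Python sketch below uses the same membership test as a sorting-free approach: a token is kept when the tokens strictly more likely than it hold less than top_p of the mass. It is quadratic in the vocabulary size and meant only as an illustration. `top_p_renorm` is a hypothetical name.

```python
def top_p_renorm(probs, top_p):
    """Zero out tokens outside the top-p nucleus and renormalize the rest."""
    # Keep token i iff the total mass of strictly more likely tokens is below top_p.
    keep = [sum(q for q in probs if q > p) < top_p for p in probs]
    total = sum(p for p, k in zip(probs, keep) if k)
    return [p / total if k else 0.0 for p, k in zip(probs, keep)]
```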

top_k_renorm_probs(probs, top_k)

Fused GPU kernel for renormalizing probabilities by top-k thresholding.
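Top-k renormalization keeps the k most likely tokens and rescales them to sum to 1. In the plain-Python sketch below, a token is kept when fewer than top_k tokens are strictly more likely. As a result, ties at the k-th probability are all kept, which is a choice made for this illustration. `top_k_renorm` is a hypothetical name.

```python
def top_k_renorm(probs, top_k):
    """Zero out tokens outside the top-k set and renormalize the rest."""
    # Keep token i iff fewer than top_k tokens are strictly more likely than it.
    keep = [sum(1 for q in probs if q > p) < top_k for p in probs]
    total = sum(p for p, k in zip(probs, keep) if k)
    return [p / total if k else 0.0 for p, k in zip(probs, keep)]
```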

top_k_mask_logits(logits, top_k)

Fused GPU kernel for masking logits by top-k thresholding.
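Masking works on logits rather than probabilities. Entries outside the top-k are set to negative infinity, so a later softmax assigns them zero probability. The plain-Python sketch below illustrates this with a hypothetical `top_k_mask` helper, keeping ties at the boundary.

```python
import math


def top_k_mask(logits, top_k):
    """Set every logit outside the top-k (by value) to -inf."""
    # A logit with top_k or more strictly larger logits is masked out.
    return [x if sum(1 for y in logits if y > x) < top_k else -math.inf for x in logits]


print(top_k_mask([1.0, 3.0, 2.0, 0.5], 2))
```

Softmax is monotonic, so in this sketch, applying softmax to the masked logits keeps the same tokens as top-k renormalization of softmax(logits).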

chain_speculative_sampling(draft_probs, ...)

Fused GPU kernel for speculative sampling for sequence generation (proposed in the paper "Accelerating Large Language Model Decoding with Speculative Sampling"), where the draft model generates a sequence (chain) of tokens for each request.
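The speculative sampling rule from that paper works as follows:

1. Each draft token x is accepted with probability min(1, q(x) / p(x)), where q is the target model's distribution and p is the draft model's.
2. On the first rejection, a replacement token is drawn from the normalized residual max(0, q - p), and the chain stops.
3. If every draft token is accepted, one extra token is drawn from the target distribution at the next position.

The single-request, plain-Python sketch below illustrates that rule. It assumes that target_probs has one more row than draft_probs, to hold the bonus position. The function names are hypothetical and do not describe FlashInfer's batched kernel or its exact arguments.

```python
import random


def sample_from(weights, rng):
    """Draw i with probability proportional to weights[i]."""
    target = rng.random() * sum(weights)
    cum = 0.0
    for i, w in enumerate(weights):
        cum += w
        if cum > target:
            return i
    return max(i for i, w in enumerate(weights) if w > 0)  # round-off guard


def chain_speculative_sample(draft_probs, draft_tokens, target_probs, rng):
    """One request: draft_probs has n rows, target_probs has n + 1 rows (assumed layout)."""
    output = []
    for t, x in enumerate(draft_tokens):
        p, q = draft_probs[t][x], target_probs[t][x]
        if rng.random() < min(1.0, q / p):
            output.append(x)  # accept the draft token and continue along the chain
            continue
        # Reject: resample from the residual max(0, q - p), then stop the chain.
        residual = [max(0.0, qi - pi) for qi, pi in zip(target_probs[t], draft_probs[t])]
        output.append(sample_from(residual, rng))
        return output
    # Every draft token was accepted: draw one bonus token from the last target row.
    output.append(sample_from(target_probs[len(draft_tokens)], rng))
    return output
```

When the draft and target distributions match, every draft token is accepted. In general, the first emitted token is distributed exactly according to the target model.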