flashinfer.top_k
- flashinfer.top_k(input: Tensor, k: int, sorted: bool = False) → Tuple[Tensor, Tensor]
Radix-based Top-K selection.
This function selects the top-k largest elements from each row of the input tensor. It uses an efficient radix-based selection algorithm that is particularly fast for large vocabularies.
This is designed as a drop-in replacement for torch.topk with better performance for large tensors (vocab_size > 10000).
- Parameters:
input (torch.Tensor) – Input tensor of shape (batch_size, d) containing the values to select from. Supported dtypes: float32, float16, bfloat16.
k (int) – Number of top elements to select from each row.
sorted (bool, optional) – If True, the returned top-k elements will be sorted in descending order. Default is False (unsorted, which is faster).
- Returns:
values (torch.Tensor) – Tensor of shape (batch_size, k) containing the top-k values. Same dtype as input.
indices (torch.Tensor) – Tensor of shape (batch_size, k) with int64 dtype containing the indices of the top-k elements.
Note
Unlike torch.topk, the default behavior returns unsorted results for better performance. Set sorted=True if you need sorted output.
The radix-based algorithm is O(n) in vocabulary size, compared to O(n log k) for heap-based methods, making it faster for large vocabularies.
For small vocabularies (< 1000), torch.topk may be faster.
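The bit-by-bit selection idea behind radix top-k can be illustrated in pure Python. This is a simplified sketch for intuition only, not FlashInfer's CUDA implementation; the function names are hypothetical. It maps float bit patterns to order-preserving unsigned keys, then walks from the most significant bit down, doing O(n) work per bit pass regardless of k (contrast with heapq.nlargest, which is O(n log k)).

```python
import heapq
import struct

def _float_to_ordered_uint(x: float) -> int:
    """Map a float32 bit pattern to an unsigned int that sorts in the
    same order as the float (the standard radix-sort key transform)."""
    (u,) = struct.unpack(">I", struct.pack(">f", x))
    # Negative floats: flip all bits; non-negative: set the sign bit.
    return (u ^ 0xFFFFFFFF) if (u & 0x80000000) else (u | 0x80000000)

def radix_top_k(values, k):
    """Return indices of the k largest values via bitwise radix selection.

    Invariant: all remaining candidates agree on every bit already
    processed, so within the candidate pool, elements with the current
    bit set are strictly larger than those with it clear.
    """
    candidates = [(_float_to_ordered_uint(v), i) for i, v in enumerate(values)]
    selected = []  # indices known to be in the top-k (unsorted)
    for bit in range(31, -1, -1):
        ones = [kv for kv in candidates if kv[0] >> bit & 1]
        if len(selected) + len(ones) <= k:
            # Every candidate with this bit set belongs to the top-k.
            selected.extend(i for _, i in ones)
            if len(selected) == k:
                break
            candidates = [kv for kv in candidates if not (kv[0] >> bit & 1)]
        else:
            # The rest of the top-k lies within the ones; narrow the search.
            candidates = ones
    else:
        # Remaining candidates all share one key; take any to fill up to k.
        selected.extend(i for _, i in candidates[: k - len(selected)])
    return selected

# Tiny demo: indices of the two largest values.
print(sorted(radix_top_k([3.0, 1.0, 4.0, 1.5, 2.0], 2)))  # → [0, 2]
```

Like the sorted=False default of flashinfer.top_k, this returns the correct set of indices in no particular order.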
Examples
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 32000
>>> k = 256
>>> logits = torch.randn(batch_size, vocab_size, device="cuda")
>>> values, indices = flashinfer.top_k(logits, k)
>>> values.shape, indices.shape
(torch.Size([4, 256]), torch.Size([4, 256]))
With sorting enabled (for compatibility with torch.topk):
>>> values_sorted, indices_sorted = flashinfer.top_k(logits, k, sorted=True)
>>> # Values are now in descending order within each row
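If you take the faster sorted=False path and only occasionally need ordered output, the order can be recovered on the host afterwards. A stdlib-only sketch, using hypothetical unsorted (values, indices) output:

```python
# Hypothetical unsorted top-k output, as returned with sorted=False:
# the set of elements is correct but the order is arbitrary.
values = [2.7, 9.1, 5.3]
indices = [412, 88, 1903]

# Recover torch.topk-compatible ordering: sort pairs jointly by value,
# descending, keeping each value aligned with its original index.
pairs = sorted(zip(values, indices), key=lambda p: p[0], reverse=True)
values_sorted, indices_sorted = (list(t) for t in zip(*pairs))
print(values_sorted)   # → [9.1, 5.3, 2.7]
print(indices_sorted)  # → [88, 1903, 412]
```

For tensors on the GPU, sorting only k elements per row is cheap relative to the selection itself, which is why the unsorted default costs little in practice.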
See also
torch.topk – PyTorch's built-in top-k function
sampling.top_k_mask_logits – Top-k masking for logits (sets non-top-k to -inf)
sampling.top_k_renorm_probs – Top-k filtering and renormalization for probabilities