flashinfer.top_k

flashinfer.top_k(input: Tensor, k: int, sorted: bool = False) → Tuple[Tensor, Tensor]

Radix-based Top-K selection.

This function selects the k largest elements from each row of the input tensor, using an efficient radix-based selection algorithm that is particularly fast for large vocabularies.

It is designed as a near drop-in replacement for torch.topk, with better performance for large tensors (vocab_size > 10000); note that, unlike torch.topk, its output is unsorted by default.

Parameters:
  • input (torch.Tensor) – Input tensor of shape (batch_size, d) containing the values to select from. Supported dtypes: float32, float16, bfloat16.

  • k (int) – Number of top elements to select from each row.

  • sorted (bool, optional) – If True, the returned top-k elements will be sorted in descending order. Default is False (unsorted, which is faster).

Returns:

  • values (torch.Tensor) – Tensor of shape (batch_size, k) containing the top-k values. Same dtype as input.

  • indices (torch.Tensor) – Tensor of shape (batch_size, k) with int64 dtype containing the indices of the top-k elements.

Note

  • Unlike torch.topk, the default behavior returns unsorted results for better performance. Set sorted=True if you need sorted output.

  • The radix-based algorithm runs in O(n) time in the vocabulary size n, compared to O(n log k) for heap-based methods, making it faster for large vocabularies.

  • For small vocabularies (< 1000), torch.topk may be faster.
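To make the O(n) pass structure concrete, here is a toy, pure-Python sketch of radix select over non-negative integers. This is a simplified model only, not flashinfer's actual CUDA implementation, which also handles floating-point keys and parallelizes the histogram passes:

```python
# Toy sketch of radix-based top-k selection on non-negative integers.
# Each round makes one O(n) histogram pass over the surviving candidates,
# so total work is O(n) in the row length (geometric shrinkage per round).

def radix_top_k(xs, k, bits=32, radix_bits=8):
    """Return the k largest values in xs (order unspecified)."""
    result = []
    candidates = list(xs)
    shift = bits - radix_bits
    mask = (1 << radix_bits) - 1
    while shift >= 0 and k > 0:
        # One O(n) pass: histogram the current digit.
        counts = [0] * (mask + 1)
        for x in candidates:
            counts[(x >> shift) & mask] += 1
        # Walk buckets from the largest digit down to find the pivot
        # bucket containing the k-th largest element.
        taken = 0
        for d in range(mask, -1, -1):
            if taken + counts[d] >= k:
                pivot = d
                break
            taken += counts[d]
        # Every element in a strictly larger bucket is in the top-k.
        result.extend(x for x in candidates if (x >> shift) & mask > pivot)
        # Refine only the pivot bucket for the remaining slots.
        candidates = [x for x in candidates if (x >> shift) & mask == pivot]
        k -= taken
        shift -= radix_bits
    # Any remaining candidates are equal on every inspected digit.
    result.extend(candidates[:k])
    return result
```

For example, `sorted(radix_top_k([13, 200, 7, 55, 91, 3, 200, 42], 3), reverse=True)` yields `[200, 200, 91]`.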

Examples

>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 32000
>>> k = 256
>>> logits = torch.randn(batch_size, vocab_size, device="cuda")
>>> values, indices = flashinfer.top_k(logits, k)
>>> values.shape, indices.shape
(torch.Size([4, 256]), torch.Size([4, 256]))

With sorting enabled (for compatibility with torch.topk):

>>> values_sorted, indices_sorted = flashinfer.top_k(logits, k, sorted=True)
>>> # Values are now in descending order within each row
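Because the default output is unsorted, results from the two modes can be reconciled by sorting. The sketch below illustrates the comparison using torch.topk on CPU as a stand-in (flashinfer.top_k itself requires a CUDA tensor); the reconciliation logic is the same:

```python
import torch

# Sketch: reconcile unsorted top-k output with a sorted reference.
# torch.topk(..., sorted=False) stands in for an unsorted top-k here.
torch.manual_seed(0)
logits = torch.randn(4, 1000)
k = 16

vals, idx = torch.topk(logits, k, sorted=False)  # unsorted
vals_ref, idx_ref = torch.topk(logits, k)        # sorted=True by default

# Values agree once the unsorted side is sorted descending.
assert torch.equal(vals.sort(dim=-1, descending=True).values, vals_ref)
# The selected index sets match row by row, regardless of order.
assert torch.equal(idx.sort(dim=-1).values, idx_ref.sort(dim=-1).values)
```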

See also

torch.topk

PyTorch’s built-in top-k function

sampling.top_k_mask_logits

Top-k masking for logits (sets non-top-k to -inf)

sampling.top_k_renorm_probs

Top-k filtering and renormalization for probabilities