flashinfer.top_k_page_table_transform

flashinfer.top_k_page_table_transform(input: Tensor, src_page_table: Tensor, lengths: Tensor, k: int, row_to_batch: Tensor | None = None, deterministic: bool = False, tie_break: int = TopKTieBreak.NONE, dsa_graph_safe: bool = False, row_starts: Tensor | None = None) → Tensor

Fused Top-K selection + Page Table Transform for sparse attention.

This function performs top-k selection on the input scores and maps the selected indices through a page table lookup in a single fused kernel. It is used in the second stage of sparse attention, where the selected KV cache positions must be translated through page tables.

For each row i:

output_page_table[i, j] = src_page_table[batch_idx, topk_indices[i, j]]

where batch_idx is determined by row_to_batch[i] if provided, otherwise i.
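The lookup above can be sketched in pure PyTorch. This is a reference for the documented semantics only, not the fused kernel; the function name `topk_page_table_reference` is illustrative and not part of the API (row_starts is omitted for brevity):

```python
import torch

def topk_page_table_reference(scores, src_page_table, lengths, k, row_to_batch=None):
    """Unfused reference for the documented semantics (sketch, assumes k >= 1)."""
    num_rows = scores.shape[0]
    out = torch.full((num_rows, k), -1, dtype=torch.int32)
    for i in range(num_rows):
        batch_idx = int(row_to_batch[i]) if row_to_batch is not None else i
        n = min(int(lengths[i]), scores.shape[1])
        m = min(k, n)
        # Top-k scores within the valid prefix, then gather page-table entries.
        topk_indices = torch.topk(scores[i, :n], m).indices
        out[i, :m] = src_page_table[batch_idx, topk_indices]
    return out
```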

Parameters:
  • input (torch.Tensor) – Input scores tensor of shape (num_rows, max_len). Supported dtypes: float32, float16, bfloat16.

  • src_page_table (torch.Tensor) – Source page table of shape (batch_size, max_len) with dtype int32.

  • lengths (torch.Tensor) – Actual KV lengths per row of shape (num_rows,) with dtype int32.

  • k (int) – Number of top elements to select from each row.

  • row_to_batch (Optional[torch.Tensor], optional) – Mapping from row index to batch index of shape (num_rows,) with dtype int32. If None, uses 1:1 mapping (row_idx == batch_idx). Default is None.

  • deterministic (bool, optional) – If True, uses a deterministic selection algorithm so that repeated calls on the same input produce identical results. Default is False (non-deterministic, which is faster).

  • tie_break (int, optional) –

    Tie-breaking mode for equal values at the selection boundary. Supported modes (plain integers, or the corresponding TopKTieBreak enum values):

    • 0: no explicit index tie-break

    • 1: prefer smaller indices

    • 2: prefer larger indices

    Default is 0.

  • dsa_graph_safe (bool, optional) – If True, forces the FilteredTopK path and graph-safe vectorization (VEC_SIZE=1). Default is False.

  • row_starts (Optional[torch.Tensor], optional) – Per-row start indices of shape (num_rows,) with dtype int32. Top-k is computed over [row_starts[i], row_starts[i] + lengths[i]) for row i. Default is None (equivalent to all zeros).
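With row_starts, the scan window shifts within each row. A minimal sketch of the windowed selection for a single row (the helper name is illustrative, not part of the API; it assumes the returned indices are absolute positions in the row, mirroring the documented window):

```python
import torch

def windowed_topk_indices(scores_row, start, length, k):
    # Top-k within [start, start + length), returned as absolute positions
    # in the row (assumption mirroring the documented row_starts window).
    window = scores_row[start:start + length]
    m = min(k, length)
    return torch.topk(window, m).indices + start
```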

Returns:

output_page_table – Output page table entries of shape (num_rows, k) with dtype int32. Contains the gathered page table entries for the top-k indices. Positions beyond the actual length are set to -1.

Return type:

torch.Tensor

Note

  • This is specifically designed for sparse attention’s second stage.

  • If lengths[i] <= k, the output simply contains src_page_table[batch_idx, row_starts[i]:row_starts[i] + lengths[i]] (or start 0 when row_starts is None) with remaining positions set to -1.
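For instance, with k = 4 and a row of length 2, the valid entries pass through unchanged and the tail is padded with -1 (the page-table values here are made up):

```python
k = 4
length = 2
row_entries = [7, 9]  # src_page_table[batch_idx, :length] for this row (made-up values)
out_row = row_entries + [-1] * (k - length)
print(out_row)  # [7, 9, -1, -1]
```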

Examples

>>> import torch
>>> import flashinfer
>>> num_rows = 8
>>> max_len = 4096
>>> k = 256
>>> scores = torch.randn(num_rows, max_len, device="cuda", dtype=torch.float16)
>>> src_page_table = torch.randint(0, 1000, (num_rows, max_len), device="cuda", dtype=torch.int32)
>>> lengths = torch.full((num_rows,), max_len, device="cuda", dtype=torch.int32)
>>> output = flashinfer.top_k_page_table_transform(scores, src_page_table, lengths, k)
>>> output.shape
torch.Size([8, 256])
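When several query rows belong to one request (e.g. speculative decoding with multiple draft tokens per request), row_to_batch lets those rows share a single page-table row. A sketch of building such a mapping, with made-up per-request row counts:

```python
import torch

# Three requests contribute 2, 1, and 3 query rows respectively (made up).
rows_per_request = torch.tensor([2, 1, 3])
row_to_batch = torch.repeat_interleave(
    torch.arange(rows_per_request.numel(), dtype=torch.int32),
    rows_per_request,
)
print(row_to_batch)  # tensor([0, 0, 1, 2, 2, 2], dtype=torch.int32)
```

This tensor would be passed as the row_to_batch argument, with scores of shape (6, max_len) and src_page_table of shape (3, max_len).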