flashinfer.top_k_page_table_transform
- flashinfer.top_k_page_table_transform(input: Tensor, src_page_table: Tensor, lengths: Tensor, k: int, row_to_batch: Tensor | None = None) → Tensor
Fused Top-K selection + Page Table Transform for sparse attention.
This function performs top-k selection on the input scores and maps the selected indices through a page table lookup, both in a single fused kernel. It is used in the second stage of sparse attention, where the selected KV-cache positions must be translated through page tables.
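- For each row i:
output_page_table[i, j] = src_page_table[batch_idx, topk_indices[j]]
where topk_indices holds the column indices of the k highest scores among the first lengths[i] entries of input[i], and batch_idx is row_to_batch[i] if row_to_batch is provided, otherwise i.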
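The per-row rule above can be sketched as a slow, pure-Python reference, useful for checking results on small inputs. The function name top_k_page_table_transform_ref is hypothetical and not part of flashinfer. Beyond what is documented here, the sketch assumes that ties go to the lower index and that selected indices come back in ascending order. The fused kernel may order its k entries differently, so compare the valid entries of each row as sets.

```python
from typing import List, Optional, Sequence


def top_k_page_table_transform_ref(
    scores: Sequence[Sequence[float]],
    src_page_table: Sequence[Sequence[int]],
    lengths: Sequence[int],
    k: int,
    row_to_batch: Optional[Sequence[int]] = None,
) -> List[List[int]]:
    """Hypothetical pure-Python reference for the documented semantics."""
    output: List[List[int]] = []
    for i, row in enumerate(scores):
        # batch_idx = row_to_batch[i] if provided, otherwise i (1:1 mapping).
        batch_idx = row_to_batch[i] if row_to_batch is not None else i
        n = lengths[i]
        if n <= k:
            # Documented case: every valid position 0..n-1 is kept.
            topk_indices = list(range(n))
        else:
            # Assumption: ties go to the lower index, and the chosen
            # indices are returned in ascending order.
            ranked = sorted(range(n), key=lambda j: (-row[j], j))
            topk_indices = sorted(ranked[:k])
        # Gather the page table entries for the selected positions.
        entries = [src_page_table[batch_idx][j] for j in topk_indices]
        # Output slots with no valid position (only when n < k) get -1.
        output.append(entries + [-1] * (k - len(entries)))
    return output
```

For example, with scores [[0.1, 0.9, 0.5, 0.3]], a page table row [10, 11, 12, 13], lengths [4], and k = 2, the two best positions are 1 and 2, so the sketch returns [[11, 12]].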
- Parameters:
input (torch.Tensor) – Input scores tensor of shape (num_rows, max_len). Supported dtypes: float32, float16, bfloat16.
src_page_table (torch.Tensor) – Source page table of shape (batch_size, max_len) with dtype int32.
lengths (torch.Tensor) – Actual KV lengths per row, of shape (num_rows,) with dtype int32.
k (int) – Number of top elements to select from each row.
row_to_batch (Optional[torch.Tensor], optional) – Mapping from row index to batch index, of shape (num_rows,) with dtype int32. If None, a 1:1 mapping (row_idx == batch_idx) is used. Default is None.
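When several rows share one batch entry's page table (for instance, one row per query token of the same request; this scenario is an assumption, not something stated above), row_to_batch repeats each batch index once per row. Below is a minimal pure-Python sketch that uses a hypothetical helper, make_row_to_batch. It assumes each batch entry's rows are stored contiguously.

```python
from typing import List, Sequence


def make_row_to_batch(rows_per_batch: Sequence[int]) -> List[int]:
    """Hypothetical helper: build a row -> batch index mapping.

    Assumes the rows of batch entry b occupy one contiguous block,
    and that blocks appear in batch order.
    """
    mapping: List[int] = []
    for batch_idx, n_rows in enumerate(rows_per_batch):
        # Every row in this block reads from src_page_table[batch_idx].
        mapping.extend([batch_idx] * n_rows)
    return mapping
```

The resulting list can be converted into the expected argument with torch.tensor(mapping, dtype=torch.int32, device="cuda").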
- Returns:
output_page_table – Output page table entries of shape (num_rows, k) with dtype int32. Contains the gathered page table entries for the top-k indices. Positions j >= lengths[i] (which occur only when lengths[i] < k) are set to -1.
- Return type:
torch.Tensor
Note
This is specifically designed for sparse attention’s second stage.
If lengths[i] <= k, the output simply contains src_page_table[batch_idx, 0:lengths[i]] with remaining positions set to -1.
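The note implies that row i of the output holds exactly min(lengths[i], k) valid entries, followed by -1 padding. A downstream attention step can use that count as the per-row sparse length. The sketch below computes the count and checks the padding layout. Both function names are hypothetical. The check also assumes src_page_table never stores -1 as a real page index.

```python
from typing import List, Sequence


def valid_entries_per_row(lengths: Sequence[int], k: int) -> List[int]:
    # Rows with lengths[i] <= k keep lengths[i] entries; longer rows keep k.
    return [min(n, k) for n in lengths]


def check_padding(
    output_page_table: Sequence[Sequence[int]],
    lengths: Sequence[int],
    k: int,
) -> bool:
    """Hypothetical check: valid entries first, then only -1 padding."""
    for row, n_valid in zip(output_page_table, valid_entries_per_row(lengths, k)):
        # Assumes a real page index is never -1.
        if any(e == -1 for e in row[:n_valid]):
            return False
        if any(e != -1 for e in row[n_valid:]):
            return False
    return True
```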
Examples
>>> import torch
>>> import flashinfer
>>> num_rows = 8
>>> max_len = 4096
>>> k = 256
>>> scores = torch.randn(num_rows, max_len, device="cuda", dtype=torch.float16)
>>> src_page_table = torch.randint(0, 1000, (num_rows, max_len), device="cuda", dtype=torch.int32)
>>> lengths = torch.full((num_rows,), max_len, device="cuda", dtype=torch.int32)
>>> output = flashinfer.top_k_page_table_transform(scores, src_page_table, lengths, k)
>>> output.shape
torch.Size([8, 256])