flashinfer.top_k_ragged_transform
- flashinfer.top_k_ragged_transform(input: Tensor, offsets: Tensor, lengths: Tensor, k: int, deterministic: bool = False, tie_break: int = TopKTieBreak.NONE, dsa_graph_safe: bool = False, row_starts: Tensor | None = None) → Tensor
Fused Top-K selection + Ragged Index Transform for sparse attention.
This function performs top-k selection on input scores and transforms the selected indices by adding an offset in a single fused kernel. Used in sparse attention’s second stage with ragged/variable-length KV cache.
- For each row i:
output_indices[i, j] = topk_indices[j] + offsets[i]
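The per-row formula above can be sketched in plain Python (an illustrative reference, not the fused CUDA kernel; the helper name is hypothetical, and ties here resolve toward smaller indices, whereas the kernel's default applies no explicit tie-break):

```python
def topk_ragged_transform_ref(scores, offsets, lengths, k):
    """Reference for the formula above: per-row top-k over the first
    lengths[i] scores, shifted by offsets[i], padded with -1."""
    out = []
    for row, off, n in zip(scores, offsets, lengths):
        # Rank this row's valid positions by score, descending.
        top = sorted(range(n), key=lambda j: row[j], reverse=True)[:k]
        out.append([j + off for j in top] + [-1] * (k - len(top)))
    return out

scores = [[0.1, 0.9, 0.5, 0.3],
          [0.7, 0.2, 0.8, 0.4]]
# Row 1 only considers its first 3 scores (lengths[1] == 3).
print(topk_ragged_transform_ref(scores, offsets=[0, 4], lengths=[4, 3], k=2))
# [[1, 2], [6, 4]]
```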
- Parameters:
  - input (torch.Tensor) – Input scores tensor of shape (num_rows, max_len). Supported dtypes: float32, float16, bfloat16.
  - offsets (torch.Tensor) – Offset to add per row, of shape (num_rows,) with dtype int32.
  - lengths (torch.Tensor) – Actual KV lengths per row, of shape (num_rows,) with dtype int32.
  - k (int) – Number of top elements to select from each row.
  - deterministic (bool, optional) – If True, uses deterministic mode. Default is False (non-deterministic, which is faster).
  - tie_break (int, optional) – Tie-breaking mode for equal values at the selection boundary. Supported modes (or use the TopKTieBreak enum values): 0 – no explicit index tie-break; 1 – prefer smaller indices; 2 – prefer larger indices. Default is 0.
  - dsa_graph_safe (bool, optional) – If True, forces the FilteredTopK path and graph-safe vectorization (VEC_SIZE=1). Default is False.
  - row_starts (Optional[torch.Tensor], optional) – Per-row start indices of shape (num_rows,) with dtype int32. Top-k is computed over [row_starts[i], row_starts[i] + lengths[i]) for row i. Output indices remain local_topk + offsets[i], where local_topk is relative to row_starts[i]. Default is None (equivalent to all zeros).
- Returns:
  output_indices – Output indices of shape (num_rows, k) with dtype int32. Contains the top-k indices plus offsets. Positions beyond the actual length are set to -1.
- Return type:
  torch.Tensor
Note
This is specifically designed for sparse attention’s second stage with ragged KV cache layout.
If lengths[i] <= k, the output contains [offsets[i], offsets[i]+1, …, offsets[i]+lengths[i]-1] with remaining positions set to -1.
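A minimal plain-Python illustration of this short-row rule (a sketch with a hypothetical helper name, not the library API; the actual kernel may emit the surviving indices in a different order):

```python
def short_row_ref(row_scores, offset, length, k):
    # When length <= k, all `length` valid positions survive top-k;
    # the remaining k - length output slots are filled with -1.
    top = sorted(range(length), key=lambda j: row_scores[j], reverse=True)[:k]
    return [j + offset for j in top] + [-1] * (k - len(top))

print(short_row_ref([0.3, 0.9], offset=10, length=2, k=4))
# [11, 10, -1, -1]  -> both valid positions (offset+1, offset+0), then -1 padding
```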
Examples
>>> import torch
>>> import flashinfer
>>> num_rows = 8
>>> max_len = 4096
>>> k = 256
>>> scores = torch.randn(num_rows, max_len, device="cuda", dtype=torch.float16)
>>> offsets = torch.arange(0, num_rows * max_len, max_len, device="cuda", dtype=torch.int32)
>>> lengths = torch.full((num_rows,), max_len, device="cuda", dtype=torch.int32)
>>> output = flashinfer.top_k_ragged_transform(scores, offsets, lengths, k)
>>> output.shape
torch.Size([8, 256])
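The example above uses a uniform stride (every row reserves max_len slots). For a truly ragged layout, one plausible way to build offsets is an exclusive prefix sum of the per-row lengths; this assumes rows are packed back-to-back in one flat KV buffer, which is an assumption of this sketch rather than something this page specifies:

```python
# Hypothetical ragged layout: rows packed back-to-back in a flat buffer.
lengths = [100, 250, 64]              # per-row KV lengths (example values)
offsets = [0]
for n in lengths[:-1]:
    offsets.append(offsets[-1] + n)   # exclusive prefix sum of lengths
print(offsets)
# [0, 100, 350]
```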