flashinfer.top_k_ragged_transform

flashinfer.top_k_ragged_transform(input: Tensor, offsets: Tensor, lengths: Tensor, k: int) → Tensor

Fused Top-K selection + Ragged Index Transform for sparse attention.

This function performs top-k selection on the input scores and adds a per-row offset to the selected indices, all in a single fused kernel. Selection is restricted to the first lengths[i] elements of row i. It is used in the second stage of sparse attention with a ragged (variable-length) KV cache.

For each row i and output position j:

output_indices[i, j] = topk_indices[i, j] + offsets[i]

where topk_indices[i, j] is the index of the j-th selected score within the first lengths[i] elements of row i.
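
As a point of reference, the following is a minimal pure-PyTorch sketch of the equivalent unfused semantics (illustrative only: top_k_ragged_transform_ref is not part of flashinfer, the real kernel fuses these steps, and the in-row ordering of the selected indices may differ):

    import torch

    def top_k_ragged_transform_ref(scores, offsets, lengths, k):
        # Reference semantics: per-row top-k over the valid prefix,
        # shifted by the row's offset, padded with -1 (unfused sketch).
        num_rows = scores.shape[0]
        out = torch.full((num_rows, k), -1, dtype=torch.int32, device=scores.device)
        for i in range(num_rows):
            n = min(int(lengths[i]), k)  # rows shorter than k keep all their indices
            topk = torch.topk(scores[i, : lengths[i]], n).indices
            out[i, :n] = (topk + offsets[i]).to(torch.int32)
        return out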

Parameters:
  • input (torch.Tensor) – Input scores tensor of shape (num_rows, max_len). Supported dtypes: float32, float16, bfloat16.

  • offsets (torch.Tensor) – Per-row offset added to the selected indices, of shape (num_rows,) with dtype int32.

  • lengths (torch.Tensor) – Actual KV lengths per row of shape (num_rows,) with dtype int32.

  • k (int) – Number of top elements to select from each row.

Returns:

output_indices – Output indices of shape (num_rows, k) with dtype int32, containing the top-k indices with the per-row offset added. Rows with fewer than k valid elements (lengths[i] < k) are padded with -1.

Return type:

torch.Tensor

Note

  • This is specifically designed for sparse attention’s second stage with ragged KV cache layout.

  • If lengths[i] <= k, the output contains [offsets[i], offsets[i]+1, …, offsets[i]+lengths[i]-1] with remaining positions set to -1.
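
For example (values illustrative), with offsets[i] = 100, lengths[i] = 3, and k = 4, this rule gives output_indices[i] = [100, 101, 102, -1].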

Examples

>>> import torch
>>> import flashinfer
>>> num_rows = 8
>>> max_len = 4096
>>> k = 256
>>> scores = torch.randn(num_rows, max_len, device="cuda", dtype=torch.float16)
>>> offsets = torch.arange(0, num_rows * max_len, max_len, device="cuda", dtype=torch.int32)
>>> lengths = torch.full((num_rows,), max_len, device="cuda", dtype=torch.int32)
>>> output = flashinfer.top_k_ragged_transform(scores, offsets, lengths, k)
>>> output.shape
torch.Size([8, 256])
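
A ragged variant of the same example (illustrative; it relies on the -1 padding behavior described in the note above):

>>> lengths[0] = 128  # first row now shorter than k
>>> output = flashinfer.top_k_ragged_transform(scores, offsets, lengths, k)
>>> int((output[0] == -1).sum())  # k - lengths[0] trailing positions padded with -1
128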