flashinfer.page.get_batch_indices_positions

flashinfer.page.get_batch_indices_positions(append_indptr: torch.Tensor, seq_lens: torch.Tensor, nnz: int) → Tuple[torch.Tensor, torch.Tensor]

Convert append indptr and sequence lengths to batch indices and positions.

Parameters:
  • append_indptr (torch.Tensor) – The indptr of the ragged tensor, shape: [batch_size + 1].

  • seq_lens (torch.Tensor) – The sequence length of each request in the KV-Cache, counting the appended entries (as in the example below), shape: [batch_size].

  • nnz (int) – The number of entries in the ragged tensor.

Returns:

  • batch_indices (torch.Tensor) – The batch index of each entry in the ragged tensor, shape: [nnz].

  • positions (torch.Tensor) – The position of each entry in the ragged tensor, shape: [nnz].

Example

>>> import torch
>>> import flashinfer
>>> nnz_kv = 10
>>> append_indptr = torch.tensor([0, 1, 3, 6, 10], dtype=torch.int32, device="cuda:0")
>>> seq_lens = torch.tensor([5, 5, 5, 5], dtype=torch.int32, device="cuda:0")
>>> batch_indices, positions = flashinfer.get_batch_indices_positions(append_indptr, seq_lens, nnz_kv)
>>> batch_indices
tensor([0, 1, 1, 2, 2, 2, 3, 3, 3, 3], device='cuda:0', dtype=torch.int32)
>>> positions  # position of each appended entry within its request's sequence
tensor([4, 3, 4, 2, 3, 4, 1, 2, 3, 4], device='cuda:0', dtype=torch.int32)

Note

This function is similar to the CSR2COO conversion in the cuSPARSE library, with the difference that we are converting from a ragged tensor (which doesn’t require a column indices array) to a COO format.
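
The mapping can be sketched in pure PyTorch to make the semantics concrete. This is an illustrative reference only, not the library's implementation: the helper name `batch_indices_positions_reference` is hypothetical, and the sketch assumes that `nnz` equals `append_indptr[-1]` and that `seq_lens` already counts the appended entries, as in the example above.

```python
import torch


def batch_indices_positions_reference(append_indptr, seq_lens, nnz):
    """Hypothetical pure-PyTorch sketch; not part of flashinfer."""
    # Number of appended entries per request (the row lengths of the ragged tensor).
    append_lens = (append_indptr[1:] - append_indptr[:-1]).long()
    batch_size = append_lens.numel()

    # CSR-to-COO style expansion: repeat each request index once per appended entry.
    batch_indices = torch.repeat_interleave(
        torch.arange(batch_size, dtype=torch.int32), append_lens
    )

    # Offset of each entry within its own request.
    b = batch_indices.long()
    offsets = torch.arange(nnz) - append_indptr.long()[b]

    # Appended entries occupy the last append_lens[b] slots of each sequence.
    positions = (seq_lens.long()[b] - append_lens[b] + offsets).to(torch.int32)
    return batch_indices, positions


append_indptr = torch.tensor([0, 1, 3, 6, 10], dtype=torch.int32)
seq_lens = torch.tensor([5, 5, 5, 5], dtype=torch.int32)
batch_indices, positions = batch_indices_positions_reference(append_indptr, seq_lens, 10)
print(batch_indices.tolist())  # [0, 1, 1, 2, 2, 2, 3, 3, 3, 3]
print(positions.tolist())      # [4, 3, 4, 2, 3, 4, 1, 2, 3, 4]
```

The sketch runs on CPU tensors, so it can be used to check expected outputs without a GPU; each position is computed as `seq_lens[b] - append_lens[b] + offset`, which reproduces the values in the example.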