flashinfer.page.get_batch_indices_positions
- flashinfer.page.get_batch_indices_positions(append_indptr: torch.Tensor, seq_lens: torch.Tensor, nnz: int) -> Tuple[torch.Tensor, torch.Tensor]
Convert append indptr and sequence lengths to batch indices and positions.
- Parameters:
append_indptr (torch.Tensor) – The indptr of the ragged tensor, shape: [batch_size + 1].
seq_lens (torch.Tensor) – The sequence lengths of each request in the KV-Cache, shape: [batch_size].
nnz (int) – The number of entries in the ragged tensor.
- Returns:
batch_indices (torch.Tensor) – The batch index of each entry in the ragged tensor, shape: [nnz].
positions (torch.Tensor) – The position of each entry in the ragged tensor, shape: [nnz].
Example
>>> import torch
>>> import flashinfer
>>> nnz_kv = 10
>>> append_indptr = torch.tensor([0, 1, 3, 6, 10], dtype=torch.int32, device="cuda:0")
>>> seq_lens = torch.tensor([5, 5, 5, 5], dtype=torch.int32, device="cuda:0")
>>> batch_indices, positions = flashinfer.get_batch_indices_positions(append_indptr, seq_lens, nnz_kv)
>>> batch_indices
tensor([0, 1, 1, 2, 2, 2, 3, 3, 3, 3], device='cuda:0', dtype=torch.int32)
>>> positions  # the rightmost column index of each row
tensor([4, 3, 4, 2, 3, 4, 1, 2, 3, 4], device='cuda:0', dtype=torch.int32)
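The mapping in the example can be reproduced with a small library-free sketch, which may help when checking inputs without a GPU. The helper name reference_batch_indices_positions is hypothetical and not part of flashinfer. It assumes, based only on the example output above, that seq_lens already counts the appended entries, so the appended entries of request i occupy the last positions of that request.

```python
from typing import List, Tuple


def reference_batch_indices_positions(
    append_indptr: List[int], seq_lens: List[int], nnz: int
) -> Tuple[List[int], List[int]]:
    """Hypothetical pure-Python reference, not the flashinfer implementation.

    Assumption (inferred from the documented example): seq_lens[i] includes
    the entries being appended, so they fill the tail of request i.
    """
    batch_indices: List[int] = []
    positions: List[int] = []
    for i in range(len(seq_lens)):
        # Entries append_indptr[i] .. append_indptr[i + 1] - 1 belong to request i.
        append_len = append_indptr[i + 1] - append_indptr[i]
        # Position of the first appended entry within request i.
        first_pos = seq_lens[i] - append_len
        for offset in range(append_len):
            batch_indices.append(i)
            positions.append(first_pos + offset)
    # nnz must equal the total number of appended entries.
    assert len(batch_indices) == nnz
    return batch_indices, positions


batch_indices, positions = reference_batch_indices_positions(
    [0, 1, 3, 6, 10], [5, 5, 5, 5], 10
)
print(batch_indices)  # [0, 1, 1, 2, 2, 2, 3, 3, 3, 3]
print(positions)      # [4, 3, 4, 2, 3, 4, 1, 2, 3, 4]
```

The sketch trades speed for readability: it loops over requests in Python, whereas the library function operates on device tensors.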
Note
This function is similar to the CSR2COO conversion in the cuSPARSE library, with the difference that it converts from a ragged tensor (which does not require a column indices array) to COO format.
See also