flashinfer.page.append_paged_kv_cache

flashinfer.page.append_paged_kv_cache(append_key: torch.Tensor, append_value: torch.Tensor, batch_indices: torch.Tensor, positions: torch.Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], kv_indices: torch.Tensor, kv_indptr: torch.Tensor, kv_last_page_len: torch.Tensor, kv_layout: str = 'NHD') → None

Append a batch of key-value pairs to a paged key-value cache.

Parameters:
  • append_key (torch.Tensor) – The key tensor to append in ragged tensor format, shape: [append_indptr[-1], num_kv_heads, head_dim], where append_indptr[-1] is the total number of appended entries across the batch.

  • append_value (torch.Tensor) – The value tensor to append in ragged tensor format, shape: [append_indptr[-1], num_kv_heads, head_dim].

  • batch_indices (torch.Tensor) – The batch index of each entry in the appended key-value pairs, shape: [append_indptr[-1]].

  • positions (torch.Tensor) – The position of each entry in the appended key-value pairs, shape: [append_indptr[-1]].

  • paged_kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) –

    The paged KV-Cache stored as a tuple of tensors or a single tensor:

    • a tuple (k_cache, v_cache) of 4-D tensors, each with shape: [max_num_pages, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, num_kv_heads, page_size, head_dim] if kv_layout is HND.

    • a single 5-D tensor with shape: [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND, where paged_kv_cache[:, 0] is the key-cache and paged_kv_cache[:, 1] is the value-cache.

  • kv_indices (torch.Tensor) – The page indices of the paged kv-cache, shape: [kv_indptr[-1]].

  • kv_indptr (torch.Tensor) – The indptr of the paged kv-cache, shape: [batch_size + 1].

  • kv_last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv cache, shape: [batch_size].

  • kv_layout (str) – The layout of the paged kv-cache, either NHD or HND. Defaults to NHD.
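
To make the roles of kv_indices, kv_indptr, batch_indices and positions concrete, the following is a minimal pure-Python sketch of how one appended entry could be mapped to a slot in the cache. It is an illustration under stated assumptions, not the library's kernel: the helper name locate_slot is hypothetical, and it assumes that position p of request b lands in the (p // page_size)-th page owned by that request, at offset p % page_size within the page.

```python
def locate_slot(batch_idx, position, kv_indices, kv_indptr, page_size):
    """Return (page_id, slot_in_page) for one appended entry.

    Hypothetical helper for illustration; assumes positions are counted
    from the start of the request's sequence.
    """
    page_offset, slot = divmod(position, page_size)
    first = kv_indptr[batch_idx]
    num_pages = kv_indptr[batch_idx + 1] - first
    if page_offset >= num_pages:
        raise IndexError("position falls outside the pages allocated to this request")
    # kv_indices[first : first + num_pages] lists the pages owned by the request.
    return kv_indices[first + page_offset], slot


# Four requests owning 3, 1, 2 and 2 pages, using the first 8 pages of the cache.
kv_indptr = [0, 3, 4, 6, 8]
kv_indices = list(range(8))
page_size = 16

first_slot = locate_slot(0, 0, kv_indices, kv_indptr, page_size)   # (0, 0)
last_slot = locate_slot(0, 44, kv_indices, kv_indptr, page_size)   # (2, 12)
other_slot = locate_slot(2, 20, kv_indices, kv_indptr, page_size)  # (5, 4)
```

Under this assumption, with the NHD layout the key for that entry would be stored at k_cache[page_id, slot], and with the HND layout at k_cache[page_id, :, slot].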

Example

>>> import torch
>>> import flashinfer
>>> nnz_kv = 100
>>> num_kv_heads = 32
>>> head_dim = 128
>>> k_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0)
>>> v_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0)
>>> # 45 + 8 + 25 + 22 = nnz_kv
>>> kv_append_length = torch.tensor([45, 8, 25, 22], dtype=torch.int32, device="cuda:0")
>>> kv_append_indptr = torch.cat(
...     [torch.zeros(1).int().to(0), torch.cumsum(kv_append_length, dim=0)]
... ).int()  # [0, 45, 53, 78, 100]
>>> max_num_pages = 1000
>>> page_size = 16
>>> paged_kv_cache = torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim).half().to(0)
>>> num_pages_per_req = torch.tensor([3, 1, 2, 2], dtype=torch.int32, device="cuda:0")
>>> kv_page_indptr = torch.cat(
...     [torch.zeros(1).int().to(0), torch.cumsum(num_pages_per_req, dim=0)]
... ).int()
>>> # use first 8 pages in the paged-kv
>>> kv_page_indices = torch.arange(8, dtype=torch.int32, device="cuda:0")
>>> # 45 = (3 - 1) * 16 + 13
>>> # 8 = (1 - 1) * 16 + 8
>>> # 25 = (2 - 1) * 16 + 9
>>> # 22 = (2 - 1) * 16 + 6
>>> kv_last_page_len = torch.tensor([13, 8, 9, 6], dtype=torch.int32, device="cuda:0")
>>> batch_indices, positions = flashinfer.get_batch_indices_positions(
...     kv_append_indptr, flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size), nnz_kv
... )
>>> batch_indices
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
        1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3], device='cuda:0', dtype=torch.int32)
>>> positions
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44,  0,  1,  2,  3,  4,  5,  6,  7,  0,
        1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
        12, 13, 14, 15, 16, 17, 18, 19, 20, 21], device='cuda:0',
    dtype=torch.int32)
>>> flashinfer.append_paged_kv_cache(
...     k_append,
...     v_append,
...     batch_indices,
...     positions,
...     paged_kv_cache,
...     kv_page_indices,
...     kv_page_indptr,
...     kv_last_page_len
... )
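
The batch_indices and positions tensors printed above can be reproduced with a short pure-Python sketch. This is only an illustration of the values, not the implementation of flashinfer.get_batch_indices_positions: the function name below is hypothetical, and it assumes that the appended entries of each request occupy the tail of that request's sequence, so the first appended position is seq_len - append_len.

```python
def batch_indices_positions(append_indptr, seq_lens):
    """Hypothetical reference: per-entry batch index and position.

    append_indptr: ragged offsets of the appended entries, length batch_size + 1.
    seq_lens: sequence length of each request after the append.
    """
    batch_indices, positions = [], []
    for b in range(len(seq_lens)):
        append_len = append_indptr[b + 1] - append_indptr[b]
        start = seq_lens[b] - append_len  # appended entries fill the tail
        batch_indices.extend([b] * append_len)
        positions.extend(range(start, start + append_len))
    return batch_indices, positions


append_indptr = [0, 45, 53, 78, 100]
seq_lens = [45, 8, 25, 22]  # every request starts empty in the example
batch_indices, positions = batch_indices_positions(append_indptr, seq_lens)
```

In the example every request is appended into an empty sequence, so each run of positions starts at 0. If a request already held entries before the append, its positions would start at the previous length instead.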

Note

The function assumes that the space for the appended k/v has already been allocated, which means kv_indices, kv_indptr, and kv_last_page_len must already account for the appended k/v.
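
As a sketch of what "already allocated" implies for the metadata, the following pure-Python helper derives kv_indptr and kv_last_page_len from the sequence lengths after the append. The helper name page_metadata is hypothetical and not part of the library; it relies only on the relation used in the example comments, seq_len = (num_pages - 1) * page_size + last_page_len.

```python
def page_metadata(seq_lens, page_size):
    """Hypothetical helper: page counts and last-page fill per request.

    seq_lens must be the lengths AFTER the append, since the cache
    metadata is expected to cover the appended entries.
    """
    kv_indptr = [0]
    kv_last_page_len = []
    for n in seq_lens:
        num_pages = -(-n // page_size)  # ceiling division
        kv_indptr.append(kv_indptr[-1] + num_pages)
        # A full last page reports page_size, not 0; an empty request reports 0.
        kv_last_page_len.append((n - 1) % page_size + 1 if n > 0 else 0)
    return kv_indptr, kv_last_page_len


kv_indptr, kv_last_page_len = page_metadata([45, 8, 25, 22], page_size=16)
```

For the lengths used in the example this yields the same kv_page_indptr and kv_last_page_len values that the example constructs by hand. Choosing which physical pages go into kv_indices is left to the caller's page allocator.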