flashinfer.page.append_paged_kv_cache

flashinfer.page.append_paged_kv_cache(append_key: torch.Tensor, append_value: torch.Tensor, append_indptr: torch.Tensor, kv_data: torch.Tensor, kv_indices: torch.Tensor, kv_indptr: torch.Tensor, kv_last_page_len: torch.Tensor, kv_layout: str = 'NHD')

Append a batch of key-value pairs to a paged key-value cache.

Parameters:
  • append_key (torch.Tensor) – The key tensor to append in ragged tensor format, shape: [append_indptr[-1], num_kv_heads, head_dim].

  • append_value (torch.Tensor) – The value tensor to append in ragged tensor format, shape: [append_indptr[-1], num_kv_heads, head_dim].

  • append_indptr (torch.Tensor) – The indptr tensor of the key-value pairs to append, shape: [batch_size + 1].

  • kv_data (torch.Tensor) – The 5-D tensor of the paged key-value cache, shape: [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, or [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND.

  • kv_indices (torch.Tensor) – The page indices of the paged kv-cache, shape: [kv_indptr[-1]].

  • kv_indptr (torch.Tensor) – The indptr of the paged kv-cache, shape: [batch_size + 1].

  • kv_last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv cache, shape: [batch_size].

  • kv_layout (str) – The layout of the paged kv-cache, either NHD or HND.
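
Taken together, the parameters above imply a mapping from each appended token to a page and a slot inside that page. The following is a minimal pure-Python sketch of that mapping, not the library's kernel. It assumes that the appended tokens occupy the last positions of each request's sequence, which is consistent with the note that the metadata already includes the appended k/v. The helper name append_destinations is hypothetical and is not part of flashinfer.

```python
def append_destinations(append_indptr, kv_indices, kv_indptr, kv_last_page_len, page_size):
    """Return the assumed (page_id, slot) destination of every appended token."""
    dests = []
    batch_size = len(kv_last_page_len)
    for i in range(batch_size):
        num_pages = kv_indptr[i + 1] - kv_indptr[i]
        # Sequence length AFTER the append: the metadata already includes the new tokens.
        seq_len = (num_pages - 1) * page_size + kv_last_page_len[i]
        append_len = append_indptr[i + 1] - append_indptr[i]
        for j in range(append_len):
            # Assumption: appended tokens fill the tail of the sequence.
            pos = seq_len - append_len + j
            page_id = kv_indices[kv_indptr[i] + pos // page_size]
            dests.append((page_id, pos % page_size))
    return dests


# Metadata taken from the example on this page (4 requests, page_size = 16).
dests = append_destinations(
    append_indptr=[0, 45, 53, 78, 100],
    kv_indices=list(range(8)),
    kv_indptr=[0, 3, 4, 6, 8],
    kv_last_page_len=[13, 8, 9, 6],
    page_size=16,
)
print(dests[0], dests[44], dests[45], dests[-1])
```

Under the NHD layout, and assuming index 0 of the second dimension holds keys and index 1 holds values, token t of append_key would then land in kv_data[page_id, 0, slot] and token t of append_value in kv_data[page_id, 1, slot].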

Example

>>> import torch
>>> import flashinfer
>>> nnz_kv = 100
>>> num_kv_heads = 32
>>> head_dim = 128
>>> k_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0)
>>> v_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0)
>>> # 45 + 8 + 25 + 22 = nnz_kv
>>> kv_append_length = torch.tensor([45, 8, 25, 22], dtype=torch.int32, device="cuda:0")
>>> kv_append_indptr = torch.cat(
...     [torch.zeros(1).int().to(0), torch.cumsum(kv_append_length, dim=0)]
... ).int()
>>> max_num_pages = 1000
>>> page_size = 16
>>> kv_data = torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim).half().to(0)
>>> num_pages_per_req = torch.tensor([3, 1, 2, 2], dtype=torch.int32, device="cuda:0")
>>> kv_page_indptr = torch.cat(
...     [torch.zeros(1).int().to(0), torch.cumsum(num_pages_per_req, dim=0)]
... ).int()
>>> # use first 8 pages in the paged-kv
>>> kv_page_indices = torch.arange(8, dtype=torch.int32, device="cuda:0")
>>> # 45 = (3 - 1) * 16 + 13
>>> # 8 = (1 - 1) * 16 + 8
>>> # 25 = (2 - 1) * 16 + 9
>>> # 22 = (2 - 1) * 16 + 6
>>> kv_last_page_len = torch.tensor([13, 8, 9, 6], dtype=torch.int32, device="cuda:0")
>>>
>>> flashinfer.append_paged_kv_cache(
...     k_append,
...     v_append,
...     kv_append_indptr,
...     kv_data,
...     kv_page_indices,
...     kv_page_indptr,
...     kv_last_page_len
... )

Notes

Please refer to the tutorial for a detailed explanation of the paged KV-cache layout.

The function assumes that space for the appended k/v has already been allocated, which means kv_indices, kv_indptr, and kv_last_page_len must already account for the appended k/v.
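
Because the caller is responsible for updating the metadata before the append, it can help to see how the page counts and last-page lengths follow from the sequence lengths. The sketch below is a plain-Python illustration under that assumption. The helper name plan_append and its arguments are hypothetical and not part of flashinfer, and choosing the actual page ids for kv_indices is left to whatever page allocator the caller uses.

```python
def plan_append(old_seq_lens, append_lens, page_size):
    """Compute page counts, kv_indptr and kv_last_page_len after an append."""
    num_pages, last_page_len = [], []
    for old_len, add_len in zip(old_seq_lens, append_lens):
        new_len = old_len + add_len
        pages = (new_len + page_size - 1) // page_size  # ceiling division
        num_pages.append(pages)
        # Entries in the last page; a completely full last page holds page_size entries.
        last_page_len.append(new_len - (pages - 1) * page_size if pages > 0 else 0)
    # Running sum of page counts, starting at 0, gives kv_indptr.
    kv_indptr = [0]
    for p in num_pages:
        kv_indptr.append(kv_indptr[-1] + p)
    return num_pages, kv_indptr, last_page_len


# Four empty requests receiving 45, 8, 25 and 22 tokens, as in the example on this page.
num_pages, kv_indptr, kv_last_page_len = plan_append([0, 0, 0, 0], [45, 8, 25, 22], 16)
print(num_pages, kv_indptr, kv_last_page_len)
```

For these inputs the helper reproduces the num_pages_per_req, kv_page_indptr, and kv_last_page_len values used in the example.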