flashinfer.page.append_paged_kv_cache
- flashinfer.page.append_paged_kv_cache(append_key: torch.Tensor, append_value: torch.Tensor, batch_indices: torch.Tensor, positions: torch.Tensor, paged_kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], kv_indices: torch.Tensor, kv_indptr: torch.Tensor, kv_last_page_len: torch.Tensor, kv_layout: str = 'NHD') -> None
Append a batch of key-value pairs to a paged key-value cache.
- Parameters:
append_key (torch.Tensor) – The key tensor to append, in ragged tensor format, with shape [append_indptr[-1], num_kv_heads, head_dim], where append_indptr[-1] is the total number of appended entries across the batch.
append_value (torch.Tensor) – The value tensor to append, in ragged tensor format, with shape [append_indptr[-1], num_kv_heads, head_dim].
batch_indices (torch.Tensor) – The batch index of each entry in the appended key-value pairs, with shape [append_indptr[-1]].
positions (torch.Tensor) – The position of each entry in the appended key-value pairs within its request's sequence, with shape [append_indptr[-1]].
paged_kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – The paged KV-Cache, stored either as a tuple of tensors or as a single tensor:
  - a tuple (k_cache, v_cache) of 4-D tensors, each with shape [max_num_pages, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, num_kv_heads, page_size, head_dim] if kv_layout is HND.
  - a single 5-D tensor with shape [max_num_pages, 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, and [max_num_pages, 2, num_kv_heads, page_size, head_dim] if kv_layout is HND, where paged_kv_cache[:, 0] is the key-cache and paged_kv_cache[:, 1] is the value-cache.
kv_indices (torch.Tensor) – The page indices of the paged kv-cache, with shape [kv_indptr[-1]].
kv_indptr (torch.Tensor) – The indptr of the paged kv-cache, with shape [batch_size + 1].
kv_last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv-cache, with shape [batch_size].
kv_layout (str) – The layout of the paged kv-cache, either NHD or HND.
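The parameters above fit together as follows: each request's sequence length is recovered from kv_indptr and kv_last_page_len, each appended entry gets a batch index and a position counted from the end of the already-allocated sequence, and a (request, position) pair is mapped to a page id through kv_indices and to a slot within that page through position % page_size. The sketch below is a hypothetical pure-Python model of that addressing, written from the shapes and descriptions on this page rather than from FlashInfer's kernels. The helper names seq_lens_from_pages, batch_indices_positions, and cache_index are illustrative only; the first two are assumed to mirror what flashinfer.get_seq_lens and flashinfer.get_batch_indices_positions compute in the example. cache_index assumes the single 5-D cache form; for the (k_cache, v_cache) tuple form, drop the kv element and index k_cache or v_cache directly.

```python
# Hypothetical pure-Python model of how append_paged_kv_cache's inputs address
# the paged cache. Not FlashInfer's implementation; names are illustrative.


def seq_lens_from_pages(kv_indptr, kv_last_page_len, page_size):
    """Sequence length of each request (already including the appended entries):
    (number of pages - 1) * page_size + entries in the last page."""
    return [
        (kv_indptr[b + 1] - kv_indptr[b] - 1) * page_size + kv_last_page_len[b]
        for b in range(len(kv_last_page_len))
    ]


def batch_indices_positions(append_indptr, seq_lens):
    """Batch index and in-sequence position of every entry of the ragged append.
    The appended entries of request b occupy the last append_len positions."""
    batch_indices, positions = [], []
    for b in range(len(append_indptr) - 1):
        append_len = append_indptr[b + 1] - append_indptr[b]
        start = seq_lens[b] - append_len  # tokens that were already cached
        for j in range(append_len):
            batch_indices.append(b)
            positions.append(start + j)
    return batch_indices, positions


def cache_index(b, pos, kv_indices, kv_indptr, page_size, kv_layout, kv):
    """Index into a 5-D paged cache selecting the [num_kv_heads, head_dim] slice
    for token `pos` of request `b`; kv=0 is the key-cache, kv=1 the value-cache."""
    page_id = kv_indices[kv_indptr[b] + pos // page_size]  # which physical page
    slot = pos % page_size  # offset of the token inside that page
    if kv_layout == "NHD":
        return (page_id, kv, slot)  # cache[page_id, kv, slot, :, :]
    if kv_layout == "HND":
        return (page_id, kv, slice(None), slot)  # cache[page_id, kv, :, slot, :]
    raise ValueError(f"unknown kv_layout: {kv_layout!r}")
```

With page_size = 16, kv_indptr = [0, 3, 4, 6, 8] and kv_last_page_len = [13, 8, 9, 6], seq_lens_from_pages returns [45, 8, 25, 22], and token 44 of request 0 lands in page 2, slot 12. Returning plain index tuples keeps the model independent of torch; with a real cache tensor, paged_kv_cache[cache_index(...)] would select the [num_kv_heads, head_dim] slice for that token.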
Example
>>> import torch
>>> import flashinfer
>>> nnz_kv = 100
>>> num_kv_heads = 32
>>> head_dim = 128
>>> k_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0)
>>> v_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0)
>>> # 45 + 8 + 25 + 22 = nnz_kv
>>> kv_append_length = torch.tensor([45, 8, 25, 22], dtype=torch.int32, device="cuda:0")
>>> kv_append_indptr = torch.cat(
...     [torch.zeros(1).int().to(0), torch.cumsum(kv_append_length, dim=0)]
... ).int()  # [0, 45, 53, 78, 100]
>>> max_num_pages = 1000
>>> page_size = 16
>>> paged_kv_cache = torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim).half().to(0)
>>> num_pages_per_req = torch.tensor([3, 1, 2, 2], dtype=torch.int32, device="cuda:0")
>>> kv_page_indptr = torch.cat(
...     [torch.zeros(1).int().to(0), torch.cumsum(num_pages_per_req, dim=0)]
... ).int()
>>> # use first 8 pages in the paged-kv
>>> kv_page_indices = torch.arange(8, dtype=torch.int32, device="cuda:0")
>>> # 45 = (3 - 1) * 16 + 13
>>> # 8 = (1 - 1) * 16 + 8
>>> # 25 = (2 - 1) * 16 + 9
>>> # 22 = (2 - 1) * 16 + 6
>>> kv_last_page_len = torch.tensor([13, 8, 9, 6], dtype=torch.int32, device="cuda:0")
>>> batch_indices, positions = flashinfer.get_batch_indices_positions(
...     kv_append_indptr, flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size), nnz_kv
... )
>>> batch_indices
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       device='cuda:0', dtype=torch.int32)
>>> positions
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
        20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
        40, 41, 42, 43, 44, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6,
        7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 0, 1,
        2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
       device='cuda:0', dtype=torch.int32)
>>> flashinfer.append_paged_kv_cache(
...     k_append,
...     v_append,
...     batch_indices,
...     positions,
...     paged_kv_cache,
...     kv_page_indices,
...     kv_page_indptr,
...     kv_last_page_len
... )
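The call above works because kv_page_indices, kv_page_indptr, and kv_last_page_len were built by hand to already cover all 100 appended tokens (three pages for the 45-token request, one for the 8-token request, and so on). In a serving loop these page tables have to be grown before each append. The following is a hypothetical pure-Python sketch of that bookkeeping, not a FlashInfer API. It assumes that each request currently holds exactly ceil(seq_len / page_size) pages, that every request is non-empty after the append, and that unused page ids come from a plain Python list acting as a free list; grow_page_table and ceil_div are illustrative names.

```python
# Hypothetical page-table bookkeeping (not a FlashInfer API): grow each request's
# page list so it can hold its appended tokens, then rebuild the CSR-style
# arrays (kv_indices, kv_indptr, kv_last_page_len) that the append expects.


def ceil_div(a, b):
    """Ceiling division for non-negative integers."""
    return -(-a // b)


def grow_page_table(pages_per_req, seq_lens, append_lens, free_pages, page_size):
    """pages_per_req: one list of page ids per request (extended in place).
    seq_lens: current cached length per request (before the append).
    append_lens: number of tokens about to be appended per request.
    free_pages: unused page ids; ids are popped from the front (O(n), simple).
    Raises IndexError if the free list runs out."""
    kv_indices, kv_indptr, kv_last_page_len = [], [0], []
    for b, pages in enumerate(pages_per_req):
        new_len = seq_lens[b] + append_lens[b]  # assumed >= 1
        needed = ceil_div(new_len, page_size) - len(pages)
        for _ in range(needed):
            pages.append(free_pages.pop(0))
        kv_indices.extend(pages)
        kv_indptr.append(len(kv_indices))
        # Entries in the last page, in 1..page_size (a full page counts as page_size).
        kv_last_page_len.append(new_len - (len(pages) - 1) * page_size)
    return kv_indices, kv_indptr, kv_last_page_len
```

Called with four empty page lists, zero current lengths, append lengths [45, 8, 25, 22], free pages list(range(1000)) and page_size 16, it returns list(range(8)), [0, 3, 4, 6, 8] and [13, 8, 9, 6], matching the hand-built tensors in the example; wrapping each list in torch.tensor(..., dtype=torch.int32, device="cuda:0") gives inputs in the form the example uses. A real allocator would typically also recycle pages when requests finish, which this sketch leaves out.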
Note
The function assumes that space for the appended k/v has already been allocated, which means that kv_indices, kv_indptr, and kv_last_page_len already incorporate the appended k/v.
See also