flashinfer.page.append_paged_mla_kv_cache
- flashinfer.page.append_paged_mla_kv_cache(append_ckv: torch.Tensor, append_kpe: torch.Tensor, batch_indices: torch.Tensor, positions: torch.Tensor, ckv_cache: torch.Tensor | None, kpe_cache: torch.Tensor | None, kv_indices: torch.Tensor, kv_indptr: torch.Tensor, kv_last_page_len: torch.Tensor) → None
Append a batch of MLA key-value entries (compressed KV plus key position embedding) to a paged key-value cache. Note: currently only ckv_dim=512 and kpe_dim=64 are supported.
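To make the append semantics concrete, here is a minimal pure-Python sketch, assuming the usual FlashInfer page-table layout: request `b` owns the physical pages `kv_indices[kv_indptr[b]:kv_indptr[b+1]]`, and an entry at `positions[i]` lands in logical page `positions[i] // page_size`, slot `positions[i] % page_size`. The function `append_paged_mla_reference` is hypothetical, uses nested lists instead of tensors, and is not the library's kernel.

```python
from typing import List, Sequence


# Hypothetical pure-Python reference; not part of FlashInfer.
def append_paged_mla_reference(
    append_ckv: Sequence[Sequence[float]],  # [nnz, ckv_dim]
    append_kpe: Sequence[Sequence[float]],  # [nnz, kpe_dim]
    batch_indices: Sequence[int],           # [nnz], request of each entry
    positions: Sequence[int],               # [nnz], position within that request
    ckv_cache: List[List[List[float]]],     # [page_num, page_size, ckv_dim], updated in place
    kpe_cache: List[List[List[float]]],     # [page_num, page_size, kpe_dim], updated in place
    kv_indices: Sequence[int],              # [kv_indptr[-1]], physical page ids
    kv_indptr: Sequence[int],               # [batch_size + 1]
) -> None:
    page_size = len(ckv_cache[0])
    for i, (b, pos) in enumerate(zip(batch_indices, positions)):
        # Logical page of request b that holds this position ...
        page_iter = kv_indptr[b] + pos // page_size
        if page_iter >= kv_indptr[b + 1]:
            raise IndexError("position lies beyond the pages allocated to this request")
        # ... mapped to a physical page, plus the slot inside that page.
        page = kv_indices[page_iter]
        slot = pos % page_size
        ckv_cache[page][slot] = list(append_ckv[i])
        kpe_cache[page][slot] = list(append_kpe[i])


# Toy example: page_size=2, 4 physical pages, tiny feature dims for readability
# (the real kernel requires ckv_dim=512 and kpe_dim=64).
ckv_cache = [[[0.0, 0.0] for _ in range(2)] for _ in range(4)]
kpe_cache = [[[0.0] for _ in range(2)] for _ in range(4)]
kv_indices = [2, 0, 3]  # request 0 owns page 2; request 1 owns pages 0 and 3
kv_indptr = [0, 1, 3]

append_paged_mla_reference(
    append_ckv=[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
    append_kpe=[[10.0], [20.0], [30.0]],
    batch_indices=[0, 1, 1],
    positions=[1, 2, 3],
    ckv_cache=ckv_cache,
    kpe_cache=kpe_cache,
    kv_indices=kv_indices,
    kv_indptr=kv_indptr,
)
print(ckv_cache[3])  # [[2.0, 2.0], [3.0, 3.0]]
```

In this sketch the scatter itself only needs `kv_indices` and `kv_indptr`; `kv_last_page_len` matters when deriving each request's sequence length, and therefore the positions of the new entries.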
- Parameters:
append_ckv (torch.Tensor) – The compressed KV tensor to append, in ragged tensor format, shape: [append_indptr[-1], ckv_dim], where append_indptr[-1] is the total number of entries appended across the batch.
append_kpe (torch.Tensor) – The key position embedding tensor to append, in ragged tensor format, shape: [append_indptr[-1], kpe_dim].
batch_indices (torch.Tensor) – The batch index of each entry in the appended key-value pairs, shape: [append_indptr[-1]].
positions (torch.Tensor) – The position of each entry in the appended key-value pairs within its request's KV sequence, shape: [append_indptr[-1]].
ckv_cache (torch.Tensor) – The paged cache for the compressed KV, shape: [page_num, page_size, ckv_dim].
kpe_cache (torch.Tensor) – The paged cache for the key position embedding, shape: [page_num, page_size, kpe_dim].
kv_indices (torch.Tensor) – The page indices of the paged KV cache, shape: [kv_indptr[-1]].
kv_indptr (torch.Tensor) – The indptr of the paged KV cache, shape: [batch_size + 1].
kv_last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged KV cache, shape: [batch_size].
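The `batch_indices` and `positions` arguments can be derived from the ragged `append_indptr` and the page-table tensors. The sketch below assumes that every request owns at least one page, that `kv_indptr`/`kv_last_page_len` already describe the cache *after* the append, and that the new entries occupy the last positions of each request. Both helper functions are hypothetical stdlib reimplementations, not FlashInfer APIs.

```python
from typing import List, Sequence, Tuple


def kv_seq_lens(
    kv_indptr: Sequence[int], kv_last_page_len: Sequence[int], page_size: int
) -> List[int]:
    # Every page of a request is full except the last one, which holds
    # kv_last_page_len[b] entries (assumes each request owns at least one page).
    return [
        (kv_indptr[b + 1] - kv_indptr[b] - 1) * page_size + kv_last_page_len[b]
        for b in range(len(kv_last_page_len))
    ]


def batch_indices_positions(
    append_indptr: Sequence[int], seq_lens: Sequence[int]
) -> Tuple[List[int], List[int]]:
    batch_indices: List[int] = []
    positions: List[int] = []
    for b, seq_len in enumerate(seq_lens):
        append_len = append_indptr[b + 1] - append_indptr[b]
        # Appended entries fill the last append_len positions of request b.
        start = seq_len - append_len
        batch_indices.extend([b] * append_len)
        positions.extend(range(start, seq_len))
    return batch_indices, positions


# Two requests, page_size=2; the page table already includes the new entries.
seq_lens = kv_seq_lens(kv_indptr=[0, 1, 3], kv_last_page_len=[2, 2], page_size=2)
batch_indices, positions = batch_indices_positions(append_indptr=[0, 1, 3], seq_lens=seq_lens)
print(seq_lens, batch_indices, positions)  # [2, 4] [0, 1, 1] [1, 2, 3]
```

Keeping the per-entry `batch_indices`/`positions` as flat arrays lets each appended row be written independently, which is why the function takes them instead of recomputing them from `append_indptr` itself.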