flashinfer.page.append_paged_mla_kv_cache

flashinfer.page.append_paged_mla_kv_cache(append_ckv: torch.Tensor, append_kpe: torch.Tensor, batch_indices: torch.Tensor, positions: torch.Tensor, ckv_cache: torch.Tensor | None, kpe_cache: torch.Tensor | None, kv_indices: torch.Tensor, kv_indptr: torch.Tensor, kv_last_page_len: torch.Tensor) → None

Append a batch of key-value pairs to a paged key-value cache. Note: currently only ckv_dim=512 and kpe_dim=64 are supported.

Parameters:
  • append_ckv (torch.Tensor) – The compressed kv tensor to append in ragged tensor format, shape: [append_indptr[-1], ckv_dim].

  • append_kpe (torch.Tensor) – The key position embedding (kpe) tensor to append in ragged tensor format, shape: [append_indptr[-1], kpe_dim].

  • batch_indices (torch.Tensor) – The batch index of each entry in the appended key-value pairs, shape: [append_indptr[-1]].

  • positions (torch.Tensor) – The position of each entry in the appended key-value pairs, shape: [append_indptr[-1]].

  • ckv_cache (torch.Tensor | None) – The cache for compressed kv, shape: [page_num, page_size, ckv_dim].

  • kpe_cache (torch.Tensor | None) – The cache for key position embedding, shape: [page_num, page_size, kpe_dim].

  • kv_indices (torch.Tensor) – The page indices of the paged kv-cache, shape: [kv_indptr[-1]].

  • kv_indptr (torch.Tensor) – The indptr of the paged kv-cache, shape: [batch_size + 1].

  • kv_last_page_len (torch.Tensor) – The number of entries in the last page of each request in the paged kv cache, shape: [batch_size].
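
Example:

The parameters above can be easier to follow with a small reference for the indexing. The following is a minimal pure-Python sketch, not the library's implementation. It assumes the usual paged layout in which position p of request b lives in page kv_indices[kv_indptr[b] + p // page_size] at slot p % page_size. The helper name append_paged_mla_kv_cache_ref is hypothetical, plain nested lists stand in for the tensors, and the toy feature dimensions are far smaller than the ckv_dim=512 and kpe_dim=64 that the real function requires.

```python
def append_paged_mla_kv_cache_ref(append_ckv, append_kpe, batch_indices,
                                  positions, ckv_cache, kpe_cache,
                                  kv_indices, kv_indptr):
    """Hypothetical reference; caches are nested lists indexed [page][slot]."""
    page_size = len(ckv_cache[0])
    for i, (b, pos) in enumerate(zip(batch_indices, positions)):
        # Assumed mapping: find the page holding logical position `pos`
        # of request `b`, then the slot inside that page.
        page = kv_indices[kv_indptr[b] + pos // page_size]
        slot = pos % page_size
        ckv_cache[page][slot] = list(append_ckv[i])
        kpe_cache[page][slot] = list(append_kpe[i])


# Toy sizes for readability only.
page_num, page_size, ckv_dim, kpe_dim = 4, 2, 3, 2
ckv_cache = [[[0.0] * ckv_dim for _ in range(page_size)] for _ in range(page_num)]
kpe_cache = [[[0.0] * kpe_dim for _ in range(page_size)] for _ in range(page_num)]

# Request 0 owns pages [2, 0]; request 1 owns page [3].
kv_indices = [2, 0, 3]
kv_indptr = [0, 2, 3]

# Append 3 entries to request 0 (positions 0..2) and 1 entry to request 1.
batch_indices = [0, 0, 0, 1]
positions = [0, 1, 2, 0]
append_ckv = [[float(i + 1)] * ckv_dim for i in range(4)]
append_kpe = [[-float(i + 1)] * kpe_dim for i in range(4)]

append_paged_mla_kv_cache_ref(append_ckv, append_kpe, batch_indices, positions,
                              ckv_cache, kpe_cache, kv_indices, kv_indptr)
```

In this toy setup the first two entries of request 0 fill page 2, the third spills into slot 0 of page 0, and the single entry of request 1 lands in slot 0 of page 3. The matching kv_last_page_len would be [1, 1]; the sketch does not need it to compute addresses.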