KV-Cache Layout in FlashInfer#
Layout: NHD/HND#
FlashInfer provides two layouts for last 3 dimensions in KV-Cache: NHD
and HND
:
NHD
: the last 3 dimensions are organized as(seq_len, num_heads, head_dim)
.HND
: the last 3 dimensions are organized as(num_heads, seq_len, head_dim)
.
The NHD
layout is more natural because it’s consistent with the output of
\(xW_k\) and \(xW_v\) without transpose. The HND
layout is more friendly
for GPU implementation when KV-Cache uses low-precision data type (e.g. fp8).
In practice we don’t observe significant performance difference between these two layouts
on fp16 kV-Cache and we prioritize NHD
layout for better readability. FlashInfer implements
Attention kernels on both layouts and we provide an option to select between them (NHD
by default).
Ragged Tensor#
In batched inference/serving, the input sequence length may vary across different samples.
When there is no need to change the sequence length (e.g. in prefilling stage), we can use RaggedTensor
with a single ragged (variable length) dimension to store the key/value tensors in KV-Cache:
![Data structure of Ragged KV-Cache.](https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/ragged.png)
The keys (or values) of all requests are packed into a single data
tensor without padding,
we use a indptr
array (num_requests+1
elements, the first element is always zero)
to store the information of variable sequence lengths of each request
(indptr[i+1]-indptr[i]
is the sequence length of request i
), the data
tensor has
shape (indptr[-1], num_heads, head_dim)
when the layout is NHD
.
We can use data[indptr[i]:indptr[i+1]]
to slice the keys (or values) of request i
.
FlashInfer APIs#
FlashInfer provides flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper
to compute
the prefill attention between queries stored in ragged tensor and keys/values stored in ragged
KV-Cache.
Mask Layout (2D Ragged Tensor)#
The aforementioned Ragged Tensor can be generalized to multiple “ragged” dimensions. For example, the attention mask in FlashInfer is a 2D ragged tensor for batch size greater than 1:
![Data structure of Mask Layout.](https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/mask-layout.png)
When number of requests is greater than 1, different request might have different query length and kv length.
To avoid padding, we use a 2D ragged tensor to store attention mask. The input qo_indptr
and
kv_indptr
arrays (both with length num_requests+1
) are used to store the information of
variable sequence lengths of each request,
qo_indptr[i+1]-qo_indptr[i]
is the query length of request i
(qo_len[i]
),
kv_indptr[i+1]-kv_indptr[i]
is the kv length of request i
(kv_len[i]
).
The mask array of all requests are flattened (with query as the first dimension, and kv as last dimension)
and concatenated into a single 1D array: mask_data
. FlashInfer will create a qk_indptr
array implicitly
to store the start offset of each request’s mask in the flattened mask array: qk_indptr[1:] = cumsum(qo_len * kv_len)
.
mask_data
has shape (qk_indptr[-1],)
, we can use mask_data[qk_indptr[i]:qk_indptr[i+1]]
to slice the flattened
mask of request i
.
To save memory, we can further packes the boolean flattened boolean mask array into a bit-packed array (1 bit per element, 8 elements are packed together as a uint8) with “little” bit-order (see numpy.packbits for more details). FlashInfer accepts both boolean mask and bit-packed mask. If boolean mask is provided, FlashInfer will pack it into bit-packed array internally.
FlashInfer APIs#
flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper
and flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper
allow user to specify qo_indptr
, kv_indptr
and custom attention mask custom_mask
in begin_forward
functions,
the mask data will be added to the attention score before softmax (and after softmax scaling) in the attention kernel.
flashinfer.quantization.packbits()
and flashinfer.quantization.segment_packbits()
are the utility functions
to pack boolean mask into bit-packed array.
Page Table#
When KV-Cache is dynamic (e.g. in append or decode stage), packing all keys/values is not efficient because the sequence length per request changes over time. vLLM proposes to organize KV-Cache as a Page Table. In FlashInfer, we treat the page-table as a block sparse matrix (each used page can be viewed as an non-zero block in block sparse matrix) and uses the CSR format to index the pages in KV-Cache.
![Data structure of Paged KV-Cache.](https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/page_layout.png)
For each request, we keep an record of its page_indices
, last_page_len
which
tracks the pages used by this request and the number of entries in the last page. The KV
sequence length of request i
is page_size * (len(page_indices[i]) - 1) + last_page_length[i]
.
Note
The last_page_len
of each request must be greater than zero, and less than or equal to page_size
.
The overall kv_indptr
array (with length num_requests+1
) can be computed as:
[0, len(page_indices[0]), len(page_indices[0])+len(page_indices[1]), ...]
.
The overall kv_page_indices
array (with length kv_indptr[-1]
) is the concatenation of all requests’ page_indices
.
The overall kv_last_page_lens
array (with length num_requests
) is the concatenation of all requests’ last_page_length
.
The kv_data
tensor is a 5-D tensor with shape (in NHD
layout):
(max_num_pages, 2, page_size, num_heads, head_dim)
where max_num_pages
is the maximum number of pages used by all requests, page_size
is the number of tokens
we fit into each page. 2
is the number of slots in each page (first one for keys, the second one for values).
FlashInfer APIs#
flashinfer.page.append_paged_kv_cache()
can append a batch of keys/values (stored as ragged tensors) to the paged KV-Cache
(the pages for these appended keys/values must be allocated prior to calling this API).
flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper
and flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper
implements the decode attention
and prefill/append attention between queries stored in ragged tensors and keys/values stored in paged KV-Cache.
FAQ#
- How do FlashInfer manages KV-Cache?
FlashInfer itself is not responsible for managing the page-table (pop and allocate new pages, etc.) and we leave the strategy to the user: different serving engine might have different strategies to manage the page-table. FlashInfer is only responsible for computing the attention between queries and keys/values stored in KV-Cache.