KV-Cache Layout in FlashInfer#
Layout: NHD/HND#
FlashInfer provides two layouts for the last 3 dimensions of the KV-Cache: NHD and HND:
- NHD: the last 3 dimensions are organized as (seq_len, num_heads, head_dim).
- HND: the last 3 dimensions are organized as (num_heads, seq_len, head_dim).
The NHD layout is more natural because it is consistent with the output of \(xW_k\) and \(xW_v\) without transposition. The HND layout is friendlier for GPU implementations when the KV-Cache uses a low-precision data type (e.g. fp8). In practice we do not observe a significant performance difference between the two layouts on fp16 KV-Cache, so we prioritize the NHD layout for better readability. FlashInfer implements attention kernels on both layouts and provides an option to select between them (NHD by default).
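For illustration, here is a minimal sketch in plain PyTorch (with made-up sizes) of how the same keys are stored under the two layouts:

```python
import torch

seq_len, num_heads, head_dim = 128, 32, 128

# NHD layout: (seq_len, num_heads, head_dim), matches the shape of x @ W_k directly.
k_nhd = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16)

# HND layout: (num_heads, seq_len, head_dim), obtained by transposing the first two dims.
k_hnd = k_nhd.transpose(0, 1).contiguous()

assert k_hnd.shape == (num_heads, seq_len, head_dim)
```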
Ragged Tensor#
In batched inference/serving, the input sequence length may vary across different samples.
When there is no need to change the sequence length (e.g. in the prefill stage), we can use a RaggedTensor
with a single ragged (variable-length) dimension to store the key/value tensors in the KV-Cache:
The keys (or values) of all requests are packed into a single data tensor without padding.
We use an indptr array (num_requests+1 elements, the first element is always zero)
to store the variable sequence length of each request
(indptr[i+1]-indptr[i] is the sequence length of request i). The data tensor has
shape (indptr[-1], num_heads, head_dim) when the layout is NHD.
We can use data[indptr[i]:indptr[i+1]] to slice the keys (or values) of request i.
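A small sketch in plain PyTorch (with hypothetical sequence lengths) illustrating the indptr construction and slicing described above:

```python
import torch

num_heads, head_dim = 32, 128
seq_lens = torch.tensor([7, 3, 5])  # hypothetical per-request sequence lengths

# indptr has num_requests + 1 elements; the first element is always zero.
indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int64)
indptr[1:] = torch.cumsum(seq_lens, dim=0)

# Packed key data for all requests, NHD layout: (indptr[-1], num_heads, head_dim).
k_data = torch.randn(int(indptr[-1]), num_heads, head_dim, dtype=torch.float16)

# The keys of request i occupy rows indptr[i]:indptr[i+1].
i = 1
k_i = k_data[indptr[i]:indptr[i + 1]]
assert k_i.shape == (int(seq_lens[i]), num_heads, head_dim)
```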
FlashInfer APIs#
FlashInfer provides flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper to compute
the prefill attention between queries stored in a ragged tensor and keys/values stored in a ragged
KV-Cache.
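A rough usage sketch; the begin_forward/forward call sequence and argument order below are assumptions based on this page, so consult the API reference for the exact signature:

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 32, 32, 128

# Hypothetical per-request query / kv lengths of [7, 3, 5] and [9, 3, 6].
qo_indptr = torch.tensor([0, 7, 10, 15], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 9, 12, 18], dtype=torch.int32, device="cuda")

q = torch.randn(15, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(18, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn_like(k)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(workspace, kv_layout="NHD")

# The call sequence below is an assumption; see the API reference for exact arguments.
wrapper.begin_forward(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim)
out = wrapper.forward(q, k, v)  # out: (15, num_qo_heads, head_dim)
```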
Mask Layout (2D Ragged Tensor)#
The aforementioned ragged tensor can be generalized to multiple “ragged” dimensions. For example, the attention mask in FlashInfer is a 2D ragged tensor when the batch size is greater than 1:
When the number of requests is greater than 1, different requests might have different query lengths and kv lengths.
To avoid padding, we use a 2D ragged tensor to store the attention mask. The input qo_indptr and
kv_indptr arrays (both with length num_requests+1) store the variable sequence lengths of each request:
qo_indptr[i+1]-qo_indptr[i] is the query length of request i (qo_len[i]), and
kv_indptr[i+1]-kv_indptr[i] is the kv length of request i (kv_len[i]).
The mask arrays of all requests are flattened (with query as the first dimension and kv as the last dimension)
and concatenated into a single 1D array: mask_data. FlashInfer implicitly creates a qk_indptr array
to store the start offset of each request’s mask in the flattened mask array: qk_indptr[1:] = cumsum(qo_len * kv_len).
mask_data has shape (qk_indptr[-1],); we can use mask_data[qk_indptr[i]:qk_indptr[i+1]] to slice the flattened
mask of request i.
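A small sketch in plain PyTorch (with hypothetical lengths) of the flattened mask layout:

```python
import torch

qo_len = torch.tensor([2, 3])   # hypothetical query lengths
kv_len = torch.tensor([4, 5])   # hypothetical kv lengths

# qk_indptr[1:] = cumsum(qo_len * kv_len); the first element is zero.
qk_indptr = torch.zeros(len(qo_len) + 1, dtype=torch.int64)
qk_indptr[1:] = torch.cumsum(qo_len * kv_len, dim=0)

# Flattened boolean mask for all requests, shape (qk_indptr[-1],).
mask_data = torch.ones(int(qk_indptr[-1]), dtype=torch.bool)

# Recover the 2D mask of request i by slicing and reshaping to (qo_len[i], kv_len[i]).
i = 1
mask_i = mask_data[qk_indptr[i]:qk_indptr[i + 1]].reshape(int(qo_len[i]), int(kv_len[i]))
```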
To save memory, we can further pack the flattened boolean mask array into a bit-packed array (1 bit per element; 8 elements are packed into one uint8) with “little” bit-order (see numpy.packbits for more details). FlashInfer accepts both boolean and bit-packed masks. If a boolean mask is provided, FlashInfer will pack it into a bit-packed array internally.
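For example, with numpy:

```python
import numpy as np

bool_mask = np.array([1, 1, 0, 1, 0, 0, 1, 1, 1, 0], dtype=bool)

# Pack 8 boolean elements into each uint8, least-significant bit first ("little" bit-order).
packed = np.packbits(bool_mask, bitorder="little")

# Unpacking (and trimming the padding bits) recovers the original mask.
assert (np.unpackbits(packed, bitorder="little")[:len(bool_mask)] == bool_mask).all()
```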
FlashInfer APIs#
flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper and flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper
allow the user to specify qo_indptr, kv_indptr, and a custom attention mask custom_mask in their begin_forward functions;
the mask data is added to the attention scores before softmax (and after softmax scaling) in the attention kernel.
flashinfer.quantization.packbits() and flashinfer.quantization.segment_packbits() are utility functions
to pack a boolean mask into a bit-packed array.
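A rough sketch of packing a flattened boolean mask with these utilities; the exact return values of segment_packbits (in particular the updated indptr) are an assumption here, so check the API reference:

```python
import torch
import flashinfer

# Flattened boolean mask and its qk_indptr, following the 2D ragged layout above.
mask_data = torch.ones(2 * 4 + 3 * 5, dtype=torch.bool, device="cuda")
qk_indptr = torch.tensor([0, 8, 23], dtype=torch.int32, device="cuda")

# Pack the whole flattened mask, 8 booleans per uint8, little bit-order.
packed_mask = flashinfer.quantization.packbits(mask_data, bitorder="little")

# segment_packbits packs each request's segment separately so segments stay
# byte-aligned; returning (packed_data, packed_indptr) is an assumption here.
packed_mask_seg, packed_qk_indptr = flashinfer.quantization.segment_packbits(
    mask_data, qk_indptr, bitorder="little"
)
```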
Page Table Layout#
When the KV-Cache is dynamic (e.g. in the append or decode stage), packing all keys/values is not efficient because the sequence length per request changes over time. vLLM proposes organizing the KV-Cache as a page table. In FlashInfer, we treat the page table as a block sparse matrix (each used page can be viewed as a non-zero block in the block sparse matrix) and use the CSR format to index the pages in the KV-Cache.
For each request, we keep a record of its page_indices and last_page_len, which
track the pages used by this request and the number of entries in its last page. The KV
sequence length of request i is page_size * (len(page_indices[i]) - 1) + last_page_len[i].
Note
The last_page_len of each request must be greater than zero and less than or equal to page_size.
The overall kv_indptr array (with length num_requests+1) can be computed as
[0, len(page_indices[0]), len(page_indices[0])+len(page_indices[1]), ...].
The overall kv_page_indices array (with length kv_indptr[-1]) is the concatenation of all requests’ page_indices.
The overall kv_last_page_lens array (with length num_requests) is the concatenation of all requests’ last_page_len.
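A small sketch in plain Python (with hypothetical page assignments) of the per-request records and the overall CSR-style arrays described above:

```python
page_size = 16

# Hypothetical per-request page tables and last-page occupancy.
page_indices = [[0, 3, 5], [1], [2, 4]]
last_page_len = [9, 16, 1]

# KV sequence length of request i:
# page_size * (len(page_indices[i]) - 1) + last_page_len[i]
kv_lens = [page_size * (len(p) - 1) + l for p, l in zip(page_indices, last_page_len)]
assert kv_lens == [41, 16, 17]

# Overall CSR-style arrays over the pages of all requests.
kv_indptr = [0]
for p in page_indices:
    kv_indptr.append(kv_indptr[-1] + len(p))
kv_page_indices = [idx for p in page_indices for idx in p]
kv_last_page_lens = last_page_len

assert kv_indptr == [0, 3, 4, 6]
assert kv_page_indices == [0, 3, 5, 1, 2, 4]
```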
The kv_data tensor can either be a single 5-D tensor or a tuple of two 4-D tensors.
When stored in a single tensor, kv_data has shape:
(max_num_pages, 2, page_size, num_heads, head_dim) # NHD layout
(max_num_pages, 2, num_heads, page_size, head_dim) # HND layout
When stored in a tuple of tensors, kv_data = (k_data, v_data), and each of them has shape:
(max_num_pages, page_size, num_heads, head_dim) # NHD layout
(max_num_pages, num_heads, page_size, head_dim) # HND layout
where max_num_pages is the maximum number of pages used by all requests, page_size is the number of tokens
we fit into each page, and the 2 in the single-tensor storage indexes keys/values (the first entry holds keys, the second holds values).
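For illustration, allocating this storage in PyTorch might look like the following (sizes are hypothetical):

```python
import torch

max_num_pages, page_size, num_heads, head_dim = 1024, 16, 32, 128

# Single 5-D tensor, NHD layout; dimension 1 indexes K (0) and V (1).
kv_data = torch.zeros(
    max_num_pages, 2, page_size, num_heads, head_dim, dtype=torch.float16
)

# Equivalent tuple-of-tensors storage, NHD layout.
k_data = torch.zeros(max_num_pages, page_size, num_heads, head_dim, dtype=torch.float16)
v_data = torch.zeros_like(k_data)
kv_tuple = (k_data, v_data)
```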
FlashInfer APIs#
flashinfer.page.append_paged_kv_cache() can append a batch of keys/values (stored as ragged tensors) to the paged KV-Cache
(the pages for these appended keys/values must be allocated prior to calling this API).
flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper and flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper implement the decode attention
and prefill/append attention between queries stored in ragged tensors and keys/values stored in the paged KV-Cache.
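A rough usage sketch of the decode wrapper on the page-table arrays from the earlier sketch; the begin_forward/forward call sequence and argument order are assumptions based on this page, so consult the API reference for the exact signatures:

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 32, 128, 16
max_num_pages, num_requests = 1024, 3

# Page-table arrays from the sketch above, as int32 GPU tensors.
kv_indptr = torch.tensor([0, 3, 4, 6], dtype=torch.int32, device="cuda")
kv_page_indices = torch.tensor([0, 3, 5, 1, 2, 4], dtype=torch.int32, device="cuda")
kv_last_page_lens = torch.tensor([9, 16, 1], dtype=torch.int32, device="cuda")
kv_data = torch.zeros(
    max_num_pages, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cuda",
)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(workspace, kv_layout="NHD")

# The call sequence and argument order below are assumptions.
wrapper.begin_forward(
    kv_indptr, kv_page_indices, kv_last_page_lens,
    num_qo_heads, num_kv_heads, head_dim, page_size,
)
q = torch.randn(num_requests, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
out = wrapper.forward(q, kv_data)  # one query token per request in the decode stage
```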
Multi-level Cascade Inference Data Layout#
When using multi-level cascade inference,
the query and output are stored in ragged tensors, and the KV-Cache of all levels is stored
in a unified paged KV-Cache. Each level has a unique qo_indptr array, which is the prefix sum of the
number of tokens to append in each subtree, as well as kv_page_indptr, kv_page_indices, and
kv_last_page_len arrays, which have the same semantics as in the Page Table Layout section. The following figure
shows how to construct these data structures for an append attention operation over 8 requests whose
KV-Cache is treated as 3 levels for prefix reuse:
Note that we do not have to change the data layout of the ragged query/output tensors or the paged KV-Cache for each level.
All levels share the same underlying data layout, but we use different qo_indptr / kv_page_indptr arrays
so that we can view them in different ways.
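As an illustration (a hypothetical 2-level setup rather than the 3-level figure above), the same packed query tensor can be viewed through different per-level qo_indptr arrays:

```python
import torch

num_heads, head_dim = 32, 128

# 4 requests, each appending 2 query tokens, all sharing one common prefix.
tokens_per_request = [2, 2, 2, 2]
q = torch.randn(sum(tokens_per_request), num_heads, head_dim, dtype=torch.float16)

# Level 0 (shared prefix): a single subtree containing all requests,
# so its qo_indptr covers all 8 query tokens at once.
qo_indptr_level0 = torch.tensor([0, 8], dtype=torch.int32)

# Level 1 (per-request suffixes): one subtree per request.
qo_indptr_level1 = torch.tensor([0, 2, 4, 6, 8], dtype=torch.int32)

# Both levels view the same packed query tensor; only the indptr differs.
assert q[qo_indptr_level0[0]:qo_indptr_level0[1]].shape[0] == 8
assert q[qo_indptr_level1[2]:qo_indptr_level1[3]].shape[0] == 2
```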
FlashInfer APIs#
FlashInfer provides flashinfer.cascade.MultiLevelCascadeAttentionWrapper to compute the cascade attention.
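A rough sketch of how the per-level index arrays might be passed to this wrapper; the constructor arguments and the plan/run call sequence are assumptions, so see the API reference for the exact signatures:

```python
import torch
import flashinfer

num_levels = 3                       # matches the 3-level example above
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 32, 128, 16
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# Constructor arguments and the plan/run call sequence below are assumptions.
wrapper = flashinfer.cascade.MultiLevelCascadeAttentionWrapper(
    num_levels, workspace, kv_layout="NHD"
)

# One qo_indptr / kv_page_indptr / kv_page_indices / kv_last_page_len tensor per
# level (placeholders here), all viewing the same ragged queries and paged KV-Cache.
wrapper.plan(
    qo_indptr_arr,          # list of num_levels qo_indptr tensors
    kv_page_indptr_arr,     # list of num_levels kv_page_indptr tensors
    kv_page_indices_arr,    # list of num_levels kv_page_indices tensors
    kv_last_page_len_arr,   # list of num_levels kv_last_page_len tensors
    num_qo_heads, num_kv_heads, head_dim, page_size,
)
out = wrapper.run(q, kv_data)  # q: ragged queries, kv_data: unified paged KV-Cache
```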
FAQ#
- How does FlashInfer manage KV-Cache?
FlashInfer itself is not responsible for managing the page table (popping and allocating new pages, etc.); we leave that strategy to the user, as different serving engines might have different strategies to manage the page table. FlashInfer is only responsible for computing the attention between queries and the keys/values stored in the KV-Cache.