KV-Cache Layout in FlashInfer
Layout: NHD/HND
FlashInfer provides two layouts for the last 3 dimensions of the KV-Cache, NHD and HND:
- NHD: the last 3 dimensions are organized as (seq_len, num_heads, head_dim).
- HND: the last 3 dimensions are organized as (num_heads, seq_len, head_dim).
The NHD layout is more natural because it matches the output of \(xW_k\) and \(xW_v\) without a transpose. The HND layout is friendlier to GPU implementations when the KV-Cache uses a low-precision data type (e.g. fp8). In practice we don't observe a significant performance difference between the two layouts with an fp16 KV-Cache, and we prioritize the NHD layout for better readability. FlashInfer implements attention kernels for both layouts and provides an option to select between them (NHD by default).
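To make the difference concrete, the pure-Python sketch below computes the row-major flat offset of element (token s, head h, channel d) in each layout and converts a flat NHD buffer into HND order. The helper names are hypothetical and not part of FlashInfer; for a PyTorch tensor of shape (seq_len, num_heads, head_dim), the same conversion would be k.transpose(0, 1).contiguous().

```python
def nhd_offset(s: int, h: int, d: int, seq_len: int, num_heads: int, head_dim: int) -> int:
    """Flat offset of (token s, head h, channel d) in a row-major NHD buffer."""
    return (s * num_heads + h) * head_dim + d


def hnd_offset(s: int, h: int, d: int, seq_len: int, num_heads: int, head_dim: int) -> int:
    """Flat offset of (token s, head h, channel d) in a row-major HND buffer."""
    return (h * seq_len + s) * head_dim + d


def nhd_to_hnd(buf: list, seq_len: int, num_heads: int, head_dim: int) -> list:
    """Reorder a flat NHD buffer into HND order (swap the first two axes)."""
    dims = (seq_len, num_heads, head_dim)
    out = [None] * len(buf)
    for s in range(seq_len):
        for h in range(num_heads):
            for d in range(head_dim):
                out[hnd_offset(s, h, d, *dims)] = buf[nhd_offset(s, h, d, *dims)]
    return out


# In HND, all tokens of one head are contiguous; in NHD, consecutive tokens
# of the same head are num_heads * head_dim elements apart.
seq_len, num_heads, head_dim = 2, 3, 4
nhd = list(range(seq_len * num_heads * head_dim))
hnd = nhd_to_hnd(nhd, seq_len, num_heads, head_dim)
print(hnd[:8])  # [0, 1, 2, 3, 12, 13, 14, 15]: head 0, token 0 then token 1
```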
Ragged Tensor
In batched inference/serving, the input sequence length may vary across samples. When there is no need to change the sequence length (e.g. in the prefill stage), we can use a RaggedTensor with a single ragged (variable-length) dimension to store the key/value tensors in the KV-Cache.
The keys (or values) of all requests are packed into a single data tensor without padding. We use an indptr array (num_requests+1 elements, the first element always zero) to store the variable sequence lengths: indptr[i+1]-indptr[i] is the sequence length of request i. When the layout is NHD, the data tensor has shape (indptr[-1], num_heads, head_dim).
We can use data[indptr[i]:indptr[i+1]] to slice the keys (or values) of request i.
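The indptr bookkeeping can be sketched in plain Python. The helper names are hypothetical, and in practice data would be a tensor of shape (indptr[-1], num_heads, head_dim) rather than a list of placeholder rows.

```python
from itertools import accumulate


def build_indptr(seq_lens: list[int]) -> list[int]:
    """indptr[0] = 0 and indptr[i+1] - indptr[i] = seq_lens[i]."""
    return [0, *accumulate(seq_lens)]


def pack_ragged(per_request: list[list]) -> tuple[list, list[int]]:
    """Concatenate per-request rows (one row per token) without padding."""
    data = [row for rows in per_request for row in rows]
    indptr = build_indptr([len(rows) for rows in per_request])
    return data, indptr


def request_slice(data: list, indptr: list[int], i: int) -> list:
    """Equivalent of data[indptr[i]:indptr[i+1]] in the text."""
    return data[indptr[i]:indptr[i + 1]]


# Three requests with 2, 1 and 3 tokens; each string stands in for a
# (num_heads, head_dim) slice of the NHD data tensor.
reqs = [["a0", "a1"], ["b0"], ["c0", "c1", "c2"]]
data, indptr = pack_ragged(reqs)
print(indptr)                          # [0, 2, 3, 6]
print(request_slice(data, indptr, 2))  # ['c0', 'c1', 'c2']
```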
FlashInfer APIs
FlashInfer provides flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper to compute the prefill attention between queries stored in a ragged tensor and keys/values stored in a ragged KV-Cache.
Mask Layout (2D Ragged Tensor)
The aforementioned Ragged Tensor can be generalized to multiple “ragged” dimensions. For example, the attention mask in FlashInfer is a 2D ragged tensor when the batch size is greater than 1.
When the number of requests is greater than 1, different requests might have different query lengths and kv lengths.
To avoid padding, we use a 2D ragged tensor to store the attention mask. The input qo_indptr and kv_indptr arrays (both of length num_requests+1) store the variable sequence lengths of each request: qo_indptr[i+1]-qo_indptr[i] is the query length of request i (qo_len[i]), and kv_indptr[i+1]-kv_indptr[i] is the kv length of request i (kv_len[i]).
The masks of all requests are flattened (with query as the first dimension and kv as the last dimension) and concatenated into a single 1D array, mask_data. FlashInfer implicitly creates a qk_indptr array that stores the start offset of each request’s mask in the flattened mask array: qk_indptr[0] = 0 and qk_indptr[1:] = cumsum(qo_len * kv_len).
mask_data has shape (qk_indptr[-1],), and we can use mask_data[qk_indptr[i]:qk_indptr[i+1]] to slice the flattened mask of request i.
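A plain-Python sketch of the 2D ragged mask bookkeeping follows. FlashInfer computes qk_indptr internally, so build_qk_indptr (like the other helper names here) is hypothetical and only mirrors the formula above; the example masks are arbitrary.

```python
from itertools import accumulate


def build_qk_indptr(qo_indptr: list[int], kv_indptr: list[int]) -> list[int]:
    """qk_indptr[0] = 0 and qk_indptr[1:] = cumsum(qo_len * kv_len)."""
    sizes = [
        (qo_indptr[i + 1] - qo_indptr[i]) * (kv_indptr[i + 1] - kv_indptr[i])
        for i in range(len(qo_indptr) - 1)
    ]
    return [0, *accumulate(sizes)]


def flatten_masks(masks: list[list[list[bool]]]) -> list[bool]:
    """Flatten each (qo_len, kv_len) mask row-major, then concatenate."""
    return [m for mask in masks for row in mask for m in row]


def request_mask(mask_data, qk_indptr, qo_indptr, kv_indptr, i):
    """Slice request i's flat mask and reshape it to (qo_len, kv_len)."""
    qo_len = qo_indptr[i + 1] - qo_indptr[i]
    kv_len = kv_indptr[i + 1] - kv_indptr[i]
    flat = mask_data[qk_indptr[i]:qk_indptr[i + 1]]
    return [flat[r * kv_len:(r + 1) * kv_len] for r in range(qo_len)]


# Request 0: qo_len = 2, kv_len = 3; request 1: qo_len = 1, kv_len = 2.
qo_indptr, kv_indptr = [0, 2, 3], [0, 3, 5]
masks = [[[True, True, False], [True, True, True]], [[True, False]]]
mask_data = flatten_masks(masks)
qk_indptr = build_qk_indptr(qo_indptr, kv_indptr)
print(qk_indptr)  # [0, 6, 8]
```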
To save memory, we can further pack the flattened boolean mask into a bit-packed array (1 bit per element, 8 elements packed together into a uint8) with “little” bit-order (see numpy.packbits for more details). FlashInfer accepts both boolean masks and bit-packed masks; if a boolean mask is provided, FlashInfer packs it into a bit-packed array internally.
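“Little” bit-order means element j of each group of 8 lands in bit j of the byte, least significant bit first. The stdlib-only sketch below (hypothetical helper names) should produce the same bytes as numpy.packbits(mask, bitorder="little"), including zero-padding of the final byte; in practice you would use numpy or the flashinfer.quantization utilities instead.

```python
def packbits_little(bits: list[bool]) -> bytes:
    """Pack booleans 8 per byte, LSB first; the last byte is zero-padded."""
    out = bytearray((len(bits) + 7) // 8)
    for j, b in enumerate(bits):
        if b:
            out[j // 8] |= 1 << (j % 8)
    return bytes(out)


def unpack_bit(packed: bytes, j: int) -> bool:
    """Read element j back from a little bit-order packed array."""
    return bool((packed[j // 8] >> (j % 8)) & 1)


mask = [True, False, True, True, False, False, False, False, True]
packed = packbits_little(mask)
print(list(packed))  # [13, 1]: bits 0, 2, 3 of byte 0, then bit 0 of byte 1
```

Nine booleans shrink to two bytes here; for large masks the saving approaches 8x relative to one byte per element.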
FlashInfer APIs
flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper and flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper allow users to specify qo_indptr, kv_indptr and a custom attention mask custom_mask in their begin_forward functions; the mask data is added to the attention score before softmax (and after softmax scaling) in the attention kernel.
flashinfer.quantization.packbits() and flashinfer.quantization.segment_packbits() are utility functions that pack a boolean mask into a bit-packed array.
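To pin down what “added to the attention score” means for a boolean mask, here is a naive single-head reference that walks qo_indptr/kv_indptr and lets each request attend only to its own keys. It is not FlashInfer's kernel or API. Two things are assumptions of this sketch: a True mask entry keeps the score (adds 0) while False adds -inf, and the default scale is 1/sqrt(head_dim). It also assumes every query row keeps at least one key.

```python
import math


def ragged_masked_attention(q, k, v, qo_indptr, kv_indptr, mask_data, sm_scale=None):
    """Naive single-head attention over ragged queries/keys with a flat custom mask.

    q: qo_indptr[-1] query vectors; k, v: kv_indptr[-1] key/value vectors.
    mask_data: flat booleans, request by request, row-major (qo_len, kv_len).
    """
    head_dim = len(q[0])
    scale = 1.0 / math.sqrt(head_dim) if sm_scale is None else sm_scale
    out = []
    qk_off = 0  # start of the current query row inside mask_data
    for i in range(len(qo_indptr) - 1):
        kv_lo, kv_hi = kv_indptr[i], kv_indptr[i + 1]
        for qi in range(qo_indptr[i], qo_indptr[i + 1]):
            scores = []
            for j, kj in enumerate(range(kv_lo, kv_hi)):
                s = scale * sum(a * b for a, b in zip(q[qi], k[kj]))
                # Boolean -> additive mask, applied after scaling, before softmax.
                s += 0.0 if mask_data[qk_off + j] else -math.inf
                scores.append(s)
            qk_off += kv_hi - kv_lo
            m = max(scores)
            w = [math.exp(s - m) for s in scores]
            z = sum(w)
            out.append([
                sum(wj * v[kj][d] for wj, kj in zip(w, range(kv_lo, kv_hi))) / z
                for d in range(len(v[kv_lo]))
            ])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = ragged_masked_attention(q, k, v, qo_indptr=[0, 1, 2], kv_indptr=[0, 2, 3],
                              mask_data=[True, False, True])
print(out)  # [[1.0, 2.0], [5.0, 6.0]]: request 0 may only see its first key
```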
Page Table Layout
When the KV-Cache is dynamic (e.g. in the append or decode stage), packing all keys/values is inefficient because the sequence length of each request changes over time. vLLM proposes organizing the KV-Cache as a page table. In FlashInfer, we treat the page table as a block sparse matrix (each used page can be viewed as a non-zero block of the matrix) and use the CSR format to index the pages in the KV-Cache.
For each request, we keep a record of its page_indices and last_page_len, which track the pages used by the request and the number of entries in its last page, respectively. The KV sequence length of request i is page_size * (len(page_indices[i]) - 1) + last_page_len[i].
Note
The last_page_len of each request must be greater than zero and less than or equal to page_size.
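A small sketch of the length formula and the constraint in the note; the helper names are hypothetical and page_size = 16 is an arbitrary example value. The inverse helper shows why last_page_len ranges over 1..page_size rather than 0..page_size-1: a request whose length is an exact multiple of page_size has a full last page.

```python
def kv_seq_len(num_pages: int, last_page_len: int, page_size: int) -> int:
    """KV length of one request: every page except the last is full."""
    if not 0 < last_page_len <= page_size:
        raise ValueError("last_page_len must satisfy 0 < last_page_len <= page_size")
    return page_size * (num_pages - 1) + last_page_len


def last_page_len_for(seq_len: int, page_size: int) -> int:
    """Entries in the last page for a request of seq_len >= 1 tokens."""
    return (seq_len - 1) % page_size + 1


page_size = 16
print(kv_seq_len(3, 2, page_size))       # 34
print(last_page_len_for(32, page_size))  # 16: exactly two full pages
```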
The overall kv_indptr array (of length num_requests+1) can be computed as [0, len(page_indices[0]), len(page_indices[0])+len(page_indices[1]), ...].
The overall kv_page_indices array (of length kv_indptr[-1]) is the concatenation of all requests’ page_indices.
The overall kv_last_page_lens array (of length num_requests) is the concatenation of all requests’ last_page_len.
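The three arrays can be assembled from per-request page lists as follows; build_page_table is a hypothetical helper and the page ids and lengths are made-up example values.

```python
from itertools import accumulate


def build_page_table(page_indices: list[list[int]], last_page_len: list[int]):
    """Concatenate per-request page lists into CSR-style arrays."""
    kv_indptr = [0, *accumulate(len(p) for p in page_indices)]
    kv_page_indices = [idx for pages in page_indices for idx in pages]
    kv_last_page_lens = list(last_page_len)
    return kv_indptr, kv_page_indices, kv_last_page_lens


# Pages need not be contiguous or sorted within a request.
page_indices = [[3, 0], [5], [1, 4, 2]]
last_page_len = [7, 16, 1]
kv_indptr, kv_page_indices, kv_last_page_lens = build_page_table(page_indices, last_page_len)
print(kv_indptr)        # [0, 2, 3, 6]
print(kv_page_indices)  # [3, 0, 5, 1, 4, 2]
# Pages of request i: kv_page_indices[kv_indptr[i]:kv_indptr[i + 1]]
```

Because pages are referenced by index, a request's pages can live anywhere in kv_data, which is what lets the block-sparse (CSR) view describe the whole cache.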
The kv_data tensor can be either a single 5-D tensor or a tuple of 4-D tensors. When stored in a single tensor, kv_data has shape:
(max_num_pages, 2, page_size, num_heads, head_dim) # NHD layout
(max_num_pages, 2, num_heads, page_size, head_dim) # HND layout
When stored as a tuple of tensors, kv_data = (k_data, v_data), and each of them has shape:
(max_num_pages, page_size, num_heads, head_dim) # NHD layout
(max_num_pages, num_heads, page_size, head_dim) # HND layout
where max_num_pages is the maximum number of pages used by all requests and page_size is the number of tokens that fit into each page. The 2 in the single-tensor storage is the K/V dimension (index 0 for keys, index 1 for values).
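To connect the page table to these shapes, the sketch below maps the t-th token of a request to a (page, offset) pair and then to an index tuple into the single 5-D kv_data tensor. Both helpers are hypothetical. With the tuple storage the K/V axis is dropped and you index k_data or v_data directly; with a PyTorch tensor, the returned tuple can be used as kv_data[idx].

```python
def token_location(t: int, pages: list[int], page_size: int) -> tuple[int, int]:
    """(page_id, offset within page) of the t-th token of one request."""
    return pages[t // page_size], t % page_size


def kv_data_index(page_id: int, kv: int, offset: int, head: int, dim: int,
                  layout: str = "NHD") -> tuple[int, ...]:
    """Index tuple into the single 5-D kv_data (kv = 0 for keys, 1 for values)."""
    if layout == "NHD":  # (max_num_pages, 2, page_size, num_heads, head_dim)
        return (page_id, kv, offset, head, dim)
    if layout == "HND":  # (max_num_pages, 2, num_heads, page_size, head_dim)
        return (page_id, kv, head, offset, dim)
    raise ValueError(f"unknown layout: {layout}")


# A request whose pages are [3, 0] with page_size = 4: token 5 is in page 0, slot 1.
page_id, offset = token_location(5, [3, 0], page_size=4)
print(kv_data_index(page_id, 1, offset, head=2, dim=7))                # (0, 1, 1, 2, 7)
print(kv_data_index(page_id, 1, offset, head=2, dim=7, layout="HND"))  # (0, 1, 2, 1, 7)
```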
FlashInfer APIs
flashinfer.page.append_paged_kv_cache() appends a batch of keys/values (stored as ragged tensors) to the paged KV-Cache (the pages for the appended keys/values must be allocated before calling this API).
flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper and flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper implement decode attention and prefill/append attention, respectively, between queries stored in ragged tensors and keys/values stored in the paged KV-Cache.
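The sketch below models the effect of an append on a toy page pool in plain Python. It is not the signature or implementation of flashinfer.page.append_paged_kv_cache(): append_to_pages, the dict-of-lists cache, and passing each request's pre-append length (old_lens) are all assumptions made for illustration. It does reflect the stated requirement that pages be allocated beforehand, raising an error when they are not.

```python
def append_to_pages(cache, page_indices, old_lens, page_size, append_rows, append_indptr):
    """Write each request's new rows after its existing old_lens[i] tokens.

    cache maps page_id -> list of page_size slots; page_indices must already
    cover the post-append length. Returns the new last_page_len per request
    (assumes every request has at least one token after the append).
    """
    new_last = []
    for i, pages in enumerate(page_indices):
        rows = append_rows[append_indptr[i]:append_indptr[i + 1]]
        new_len = old_lens[i] + len(rows)
        if new_len > len(pages) * page_size:
            raise ValueError(f"request {i}: not enough pages allocated")
        for j, row in enumerate(rows):
            pos = old_lens[i] + j
            cache[pages[pos // page_size]][pos % page_size] = row
        new_last.append((new_len - 1) % page_size + 1)
    return new_last


page_size = 2
cache = {p: [None] * page_size for p in range(4)}
# Request 0 already holds 1 token and owns pages [0, 2]; request 1 is empty with page [1].
new_last = append_to_pages(cache, [[0, 2], [1]], old_lens=[1, 0], page_size=page_size,
                           append_rows=["a1", "a2", "b0"], append_indptr=[0, 2, 3])
print(new_last)  # [1, 1]
```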
Multi-level Cascade Inference Data Layout
When using multi-level cascade inference, the query and output are stored in ragged tensors, and the KV-Cache of all levels is stored in a unified paged KV-Cache. Each level has its own qo_indptr array, which is the prefix sum of the number of tokens to append in each subtree, as well as kv_page_indptr, kv_page_indices, and kv_last_page_len arrays with the same semantics as in the Page Table Layout section. The following figure illustrates how to construct these data structures for the append attention operation for 8 requests, where we treat their KV-Cache as 3 levels for prefix reuse:
Note that we don’t have to change the data layout of the ragged query/output tensor or the paged KV-Cache for each level. All levels share the same underlying data layout; we use different qo_indptr / kv_page_indptr arrays to view it in different ways.
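Because every level reuses the same ragged query tensor, a level's qo_indptr is a coarser prefix sum over the same tokens. The sketch below assumes the requests are ordered so that each subtree covers a contiguous run of requests; level_qo_indptr and the qo lengths are hypothetical and not taken from the figure. Each level's kv_page_indptr/kv_page_indices would be built per node in the same way as in the Page Table Layout section.

```python
from itertools import accumulate


def level_qo_indptr(qo_lens: list[int], group_sizes: list[int]) -> list[int]:
    """qo_indptr of one cascade level.

    group_sizes[g] is the number of consecutive requests whose KV-Cache shares
    node g at this level; requests must be ordered so each subtree is contiguous.
    """
    if sum(group_sizes) != len(qo_lens):
        raise ValueError("groups must cover every request exactly once")
    bounds = [0, *accumulate(group_sizes)]
    tokens = [sum(qo_lens[bounds[g]:bounds[g + 1]]) for g in range(len(group_sizes))]
    return [0, *accumulate(tokens)]


# Hypothetical numbers of tokens to append for 8 requests, viewed as 3 levels.
qo_lens = [2, 1, 3, 1, 2, 2, 1, 4]
print(level_qo_indptr(qo_lens, [8]))      # level 0, one shared prefix: [0, 16]
print(level_qo_indptr(qo_lens, [4, 4]))   # level 1, two subtrees: [0, 7, 16]
print(level_qo_indptr(qo_lens, [1] * 8))  # level 2: [0, 2, 3, 6, 7, 9, 11, 12, 16]
```

Every level ends at the same total (16 here), reflecting that all levels index into one ragged query tensor.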
FlashInfer APIs
FlashInfer provides flashinfer.cascade.MultiLevelCascadeAttentionWrapper to compute the cascade attention.
FAQ
- How does FlashInfer manage the KV-Cache?
FlashInfer itself is not responsible for managing the page table (popping and allocating new pages, etc.); we leave the strategy to the user, because different serving engines might have different strategies for managing the page table. FlashInfer is only responsible for computing the attention between queries and the keys/values stored in the KV-Cache.