Attention States and Recursive Attention

FlashInfer introduces the concept of attention states, which fully characterize the attention between a query and a set of key/value pairs. We further define a merge operator on attention states. This merge operator makes it possible to compute the complete attention by recursively merging attention states.

Suppose we define \(s_i = \mathbf{q}\mathbf{k}_i^T\) as the pre-softmax attention score between the query \(\mathbf{q}\) and the key \(\mathbf{k}_i\). The self-attention score on index \(i\) can be generalized to an index set \(I\):

\[s(I)=\log\left(\sum_{i\in I}\exp\left(s_i\right)\right)\]
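As a minimal sketch (plain Python, standard library only; the function name `lse` is our own choice, not a FlashInfer API), the generalized score can be computed with the usual max-subtraction trick, which avoids overflow when individual scores are large:

```python
import math
from typing import Sequence


def lse(scores: Sequence[float]) -> float:
    """Generalized score s(I) = log(sum_{i in I} exp(s_i)) for a non-empty I.

    Subtracting the maximum first keeps every exp() argument <= 0,
    so large scores cannot overflow.
    """
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


print(lse([2.5]))             # singleton set: s({i}) = s_i, prints 2.5
print(lse([1000.0, 1000.0]))  # 1000 + log(2); math.exp(1000.0) alone would overflow
```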

We can also generalize the value on index \(i\) to an index set \(I\):

\[\mathbf{v}(I) = \sum_{i\in I}\textrm{softmax}(s_i) \mathbf{v}_i = \frac{\sum_{i\in I}\exp\left(s_i\right)\mathbf{v}_i}{\exp(s(I))}\]
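The two definitions can be put together in a short sketch that computes the pair \((s(I), \mathbf{v}(I))\) for one query over an index set. The helper name `attention_state` and the use of plain Python lists as vectors are illustrative assumptions; the maximum score is subtracted before exponentiating for numerical stability.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def attention_state(q: Vector, keys: Sequence[Vector],
                    values: Sequence[Vector]) -> Tuple[float, Vector]:
    """Return (s(I), v(I)) for query q over the index set I = range(len(keys))."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]  # s_i = q k_i^T
    m = max(scores)                                        # subtract the max for stability
    weights = [math.exp(s - m) for s in scores]            # exp(s_i) / exp(m)
    total = sum(weights)                                   # exp(s(I)) / exp(m)
    dim = len(values[0])
    # v(I): softmax restricted to I, then a weighted sum of the v_i
    v = [sum(w * val[d] for w, val in zip(weights, values)) / total for d in range(dim)]
    s = m + math.log(total)                                # s(I) = log sum_i exp(s_i)
    return s, v


q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0]]    # scores s_1 = 1, s_2 = 0
values = [[1.0, 0.0], [0.0, 1.0]]
s, v = attention_state(q, keys, values)
# Mathematically: s = log(e + 1), v = [e / (e + 1), 1 / (e + 1)]
```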

Here the \(\textrm{softmax}\) function is restricted to the index set \(I\). Note that \(\mathbf{v}(\{1,2,\cdots, n\})\) is the self-attention output of the entire sequence. The attention state of the index set \(I\) is defined as the tuple \((s(I), \mathbf{v}(I))\). We can then define a binary merge operator \(\oplus\) on the attention states of two disjoint index sets \(I\) and \(J\) as follows (in practice we subtract the maximum score from \(s\) before exponentiating to guarantee numerical stability; we omit this here for simplicity):

\[\begin{split}\begin{bmatrix}\mathbf{v}(I\cup J)\\s(I\cup J)\end{bmatrix}=\begin{bmatrix}\mathbf{v}(I)\\s(I)\end{bmatrix}\oplus\begin{bmatrix}\mathbf{v}(J)\\s(J)\end{bmatrix}=\begin{bmatrix} \frac{\mathbf{v}(I)\exp(s(I)) + \mathbf{v}(J)\exp(s(J))}{\exp(s(I)) + \exp(s(J))} \\ \log(\exp(s(I)) + \exp(s(J))) \end{bmatrix}\end{split}\]
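A direct transcription of the binary merge into code might look like the following minimal sketch. The `merge` name and the tuple representation `(s, v)` are illustrative assumptions, not FlashInfer's API. The sketch includes the max-subtraction that the formula omits and, like the formula, assumes \(I\) and \(J\) are disjoint.

```python
import math
from typing import List, Tuple

State = Tuple[float, List[float]]  # an attention state (s(I), v(I))


def merge(a: State, b: State) -> State:
    """a ⊕ b for attention states over disjoint index sets I and J."""
    s_a, v_a = a
    s_b, v_b = b
    m = max(s_a, s_b)                   # subtract the larger score so exp() cannot overflow
    w_a = math.exp(s_a - m)             # exp(s(I)) / exp(m)
    w_b = math.exp(s_b - m)             # exp(s(J)) / exp(m)
    v = [(w_a * x + w_b * y) / (w_a + w_b) for x, y in zip(v_a, v_b)]
    s = m + math.log(w_a + w_b)         # log(exp(s(I)) + exp(s(J)))
    return s, v


# The state of a singleton set {i} is simply (s_i, v_i), so merging two
# singletons reproduces softmax attention over two keys.
s, v = merge((1.0, [1.0, 0.0]), (0.0, [0.0, 1.0]))
# Large scores are handled without overflow:
s_big, v_big = merge((1000.0, [1.0]), (1000.0, [3.0]))   # v_big == [2.0]
```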

The merge operator can be generalized to any number of attention states over pairwise disjoint index sets:

\[\begin{split}\begin{bmatrix}\mathbf{v}(\bigcup_{i=1}^{n}I_i) \\ s(\bigcup_{i=1}^{n}I_i) \end{bmatrix} = \bigoplus_{i=1}^{n}\begin{bmatrix}\mathbf{v}(I_i) \\ s(I_i)\end{bmatrix} = \begin{bmatrix} \sum_{i=1}^{n} \textrm{softmax}(s(I_i))\mathbf{v}(I_i) \\ \log(\sum_{i=1}^{n} \exp (s(I_i))) \end{bmatrix}\end{split}\]

The above n-ary merge operator is consistent with the binary merge operator, and one can prove that the operator is commutative and associative. As a result, there are many ways to obtain the attention state of the entire sequence by merging the attention states of index subsets, and all of them yield mathematically equivalent results:

Figure: Recursive Attention
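These properties can be checked numerically. The sketch below (plain Python; `merge`, `merge_n`, `close`, and the random test data are our own illustrative choices) implements both the binary and the n-ary merge with max-subtraction, builds singleton states \((s_i, \mathbf{v}_i)\) for a few keys, and compares a one-shot n-ary merge against a left-to-right fold, a merge of two separately reduced halves, and a merge in shuffled order. All four agree up to floating-point rounding.

```python
import math
import random
from functools import reduce
from typing import List, Sequence, Tuple

State = Tuple[float, List[float]]  # an attention state (s(I), v(I))


def merge(a: State, b: State) -> State:
    """Binary merge a ⊕ b of states over disjoint index sets."""
    m = max(a[0], b[0])
    wa, wb = math.exp(a[0] - m), math.exp(b[0] - m)
    v = [(wa * x + wb * y) / (wa + wb) for x, y in zip(a[1], b[1])]
    return m + math.log(wa + wb), v


def merge_n(states: Sequence[State]) -> State:
    """n-ary merge: softmax over the s(I_i), then a weighted sum of the v(I_i)."""
    m = max(s for s, _ in states)
    w = [math.exp(s - m) for s, _ in states]
    total = sum(w)
    dim = len(states[0][1])
    v = [sum(wi * vec[d] for wi, (_, vec) in zip(w, states)) / total for d in range(dim)]
    return m + math.log(total), v


def close(a: State, b: State, tol: float = 1e-9) -> bool:
    """True if two states agree to within tol in s and in every component of v."""
    return abs(a[0] - b[0]) < tol and all(abs(x - y) < tol for x, y in zip(a[1], b[1]))


rng = random.Random(0)
# Singleton states (s_i, v_i): 8 keys, 4-dimensional values.
singletons = [(rng.uniform(-5, 5), [rng.uniform(-1, 1) for _ in range(4)]) for _ in range(8)]

reference = merge_n(singletons)                    # one-shot n-ary merge
left_fold = reduce(merge, singletons)              # ((x1 ⊕ x2) ⊕ x3) ⊕ ...
halves = merge(reduce(merge, singletons[4:]),      # reduce each half separately,
               reduce(merge, singletons[:4]))      # then combine in swapped order
shuffled = singletons[:]
rng.shuffle(shuffled)
permuted = reduce(merge, shuffled)                 # an arbitrary merge order

print(close(reference, left_fold), close(reference, halves), close(reference, permuted))
# True True True
```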

Note

The generalized score \(s\) is also known as log-sum-exp (lse for short).

Applications

Note that the \(\oplus\) operator is commutative and associative, which means we can safely offload the self-attention computation on different subsets of the KV cache to different devices and merge the results in any order.

There are several interesting applications of this recursive form of self-attention in FlashInfer so far:

Shared-Prefix Batch Decoding

Many LLM applications involve batch decoding with a long shared prompt. FlashInfer decomposes attention on the entire KV-Cache into attention on the shared prefix and attention on the unique suffixes. This decomposition enables offloading these components to different kernel implementations, resulting in a remarkable 30x acceleration in scenarios with long context and large batch size. Check our blog post for more details about this application, and Cascade Attention for how to use this feature in FlashInfer.
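The decomposition can be illustrated with a toy, single-threaded sketch. All names and sizes below are hypothetical, and this is not FlashInfer's Cascade Attention API. It computes each request's attention state over the shared prefix and over its own suffix as two separate passes, merges the two states, and checks the result against attention over the concatenated KV. On a GPU, the prefix pass can be batched across all requests because they read the same KV; the loops here only illustrate the math.

```python
import math
import random
from typing import List, Sequence, Tuple

Vector = List[float]
State = Tuple[float, Vector]  # (s(I), v(I))


def attention_state(q: Vector, keys: Sequence[Vector], values: Sequence[Vector]) -> State:
    """(s(I), v(I)) of query q over the given keys/values, with max-subtraction."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    v = [sum(wi * val[d] for wi, val in zip(w, values)) / total for d in range(len(values[0]))]
    return m + math.log(total), v


def merge(a: State, b: State) -> State:
    """Binary merge of states over disjoint index sets."""
    m = max(a[0], b[0])
    wa, wb = math.exp(a[0] - m), math.exp(b[0] - m)
    return m + math.log(wa + wb), [(wa * x + wb * y) / (wa + wb) for x, y in zip(a[1], b[1])]


def rand_vecs(rng: random.Random, n: int, dim: int) -> List[Vector]:
    return [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n)]


rng = random.Random(0)
dim, prefix_len, batch = 8, 64, 4
prefix_k, prefix_v = rand_vecs(rng, prefix_len, dim), rand_vecs(rng, prefix_len, dim)
# Each request has its own decode query and a short unique suffix.
queries = rand_vecs(rng, batch, dim)
suffixes = [(rand_vecs(rng, 3 + b, dim), rand_vecs(rng, 3 + b, dim)) for b in range(batch)]

# Pass 1: every query against the shared prefix KV.
prefix_states = [attention_state(q, prefix_k, prefix_v) for q in queries]
# Pass 2: each query against its own unique suffix KV.
suffix_states = [attention_state(q, sk, sv) for q, (sk, sv) in zip(queries, suffixes)]
# Merge the two partial states per request.
outputs = [merge(p, s)[1] for p, s in zip(prefix_states, suffix_states)]

# The merged output matches attention over the concatenated prefix + suffix KV.
for q, (sk, sv), out in zip(queries, suffixes, outputs):
    _, full = attention_state(q, prefix_k + sk, prefix_v + sv)
    assert all(abs(x - y) < 1e-9 for x, y in zip(out, full))
```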

KV Sequence Parallelism

For long-context LLM inference/serving, the batch size and number of heads per GPU are limited by GPU memory, and the default parallelism strategy cannot use all SMs of the GPU, which results in suboptimal performance. Inspired by the Split-K trick in GEMM optimization, FlashInfer partitions the KV sequence dimension, dispatches the attention computations on the partitions to different thread blocks, and merges the results in a second pass. The same idea was also proposed in Flash-Decoding; check their great blog post for visualizations and more details.
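Below is a minimal sequential sketch of the split-KV idea for a single decode query. It is written in plain Python, and `partial_attention`, `merge_states`, `split_kv_decode`, and `num_chunks` are hypothetical names. A real kernel would run the first pass on separate thread blocks in parallel and keep the partial states in GPU memory until the second pass.

```python
import math
import random
from typing import List, Sequence, Tuple

Vector = List[float]
State = Tuple[float, Vector]  # (s(I), v(I))


def partial_attention(q: Vector, keys: Sequence[Vector], values: Sequence[Vector]) -> State:
    """First pass: attention state of q over one chunk of the KV sequence."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    v = [sum(wi * val[d] for wi, val in zip(w, values)) / total for d in range(len(values[0]))]
    return m + math.log(total), v


def merge_states(states: Sequence[State]) -> State:
    """Second pass: n-ary merge of the per-chunk attention states."""
    m = max(s for s, _ in states)
    w = [math.exp(s - m) for s, _ in states]
    total = sum(w)
    v = [sum(wi * vec[d] for wi, (_, vec) in zip(w, states)) / total
         for d in range(len(states[0][1]))]
    return m + math.log(total), v


def split_kv_decode(q: Vector, keys: Sequence[Vector], values: Sequence[Vector],
                    num_chunks: int) -> Vector:
    chunk = math.ceil(len(keys) / num_chunks)
    # Chunks are independent; on a GPU each would be handled by its own thread block.
    partials = [partial_attention(q, keys[i:i + chunk], values[i:i + chunk])
                for i in range(0, len(keys), chunk)]
    return merge_states(partials)[1]


rng = random.Random(0)
dim, seq_len = 16, 1000
q = [rng.gauss(0, 1) for _ in range(dim)]
keys = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(seq_len)]
values = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(seq_len)]

baseline = split_kv_decode(q, keys, values, num_chunks=1)  # no split: plain attention
for n in (4, 7, 32):
    out = split_kv_decode(q, keys, values, num_chunks=n)
    assert all(abs(x - y) < 1e-9 for x, y in zip(out, baseline))
```

The number of chunks only changes how the work is partitioned, not the result; in a kernel it trades a small amount of extra merge work for more independent thread blocks when the batch alone cannot fill the GPU.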