flashinfer.sparse
Kernels for block-sparse FlashAttention.
- class flashinfer.sparse.BlockSparseAttentionWrapper(float_workspace_buffer: torch.Tensor)
Wrapper class for attention computation with a block-sparse matrix as the attention mask. For the definition of a block-sparse matrix, see bsr_matrix in SciPy.
This API supports any block size (R, C).
Example
>>> import torch
>>> import flashinfer
>>> num_qo_heads = 32
>>> num_kv_heads = 8
>>> head_dim = 128
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(workspace_buffer)
>>> # sparse mask: [[0, 0, 1], [1, 0, 1], [0, 1, 1]]
>>> M = 3
>>> N = 3
>>> indptr = torch.tensor([0, 1, 3, 5], dtype=torch.int32, device="cuda:0")
>>> indices = torch.tensor([2, 0, 2, 1, 2], dtype=torch.int32, device="cuda:0")
>>> bsr_wrapper.plan(
...     indptr,
...     indices,
...     M,
...     N,
...     1,  # R (block rows) = 1
...     1,  # C (block columns) = 1
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
... )
>>> q = torch.randn((M, num_qo_heads, head_dim), dtype=torch.float16, device="cuda:0")
>>> k = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device="cuda:0")
>>> v = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device="cuda:0")
>>> o = bsr_wrapper.run(q, k, v)
>>> # use dense implementation with attention mask for comparison
>>> mask = torch.tensor([[0, 0, 1], [1, 0, 1], [0, 1, 1]], dtype=torch.bool, device="cuda:0")
>>> o_ref = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask)
>>> # fp16 outputs of two different kernels: compare with an fp16-appropriate tolerance
>>> torch.allclose(o, o_ref, rtol=1e-3, atol=1e-3)
True
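The indptr and indices arrays in the example above are the block compressed sparse row (BSR) encoding of the dense mask. The pure-Python sketch below uses a hypothetical helper, dense_mask_to_bsr, which is not part of flashinfer. It derives these arrays from an element-level boolean mask for a block size (R, C). It also collects the R x C sub-mask of each non-zero block, matching the (nnz, R, C) layout that plan() accepts as mask. Padding the last partial block row with False is a choice made by this sketch.

```python
# Hypothetical helper (not a flashinfer API): build BSR arrays for plan() from
# a dense M x N boolean mask given as nested lists, with block size (R, C).
def dense_mask_to_bsr(mask, R, C):
    M, N = len(mask), len(mask[0])
    if N % C != 0:
        raise ValueError("N should be divisible by C")
    MB = -(-M // R)  # ceil_div(M, R)
    NB = N // C
    indptr, indices, block_masks = [0], [], []
    for bi in range(MB):
        for bj in range(NB):
            # R x C tile; rows past M (partial last block row) are padded with False.
            block = [
                [r < M and bool(mask[r][c]) for c in range(bj * C, (bj + 1) * C)]
                for r in range(bi * R, (bi + 1) * R)
            ]
            # Keep only blocks that contain at least one True entry.
            if any(any(row) for row in block):
                indices.append(bj)
                block_masks.append(block)
        indptr.append(len(indices))
    return indptr, indices, block_masks
```

For the mask in the example, dense_mask_to_bsr(mask, 1, 1) returns indptr [0, 1, 3, 5] and indices [2, 0, 2, 1, 2]. Converting these lists with torch.tensor(..., dtype=torch.int32) gives the tensors passed to plan().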
- __init__(float_workspace_buffer: torch.Tensor) → None
Constructor of BlockSparseAttentionWrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The user-reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB. The workspace buffer should be on the same device as the input tensors.
- plan(indptr: torch.Tensor, indices: torch.Tensor, M: int, N: int, R: int, C: int, num_qo_heads: int, num_kv_heads: int, head_dim: int, mask: torch.Tensor | None = None, packed_mask: torch.Tensor | None = None, pos_encoding_mode: str = 'NONE', allow_fp16_qk_reduction: bool = False, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, q_data_type: str | torch.dtype = 'float16', kv_data_type: str | torch.dtype | None = None, non_blocking: bool = False) → None
Create auxiliary data structures for block sparse attention.
- Parameters:
indptr (torch.Tensor) – The block index pointer of the block-sparse matrix along the row dimension, shape (MB + 1,), where MB is the number of blocks in the row dimension.
indices (torch.Tensor) – The block column indices of the block-sparse matrix, shape (nnz,), where nnz is the number of non-zero blocks. Every element of indices should be less than NB, the number of blocks in the column dimension.
M (int) – The number of rows of the block-sparse matrix; MB = ceil_div(M, R).
N (int) – The number of columns of the block-sparse matrix; NB = N // C, so N should be divisible by C.
R (int) – The number of rows in each block.
C (int) – The number of columns in each block.
num_qo_heads (int) – The number of heads in the query/output tensor.
num_kv_heads (int) – The number of heads in the key/value tensor.
head_dim (int) – The dimension of each head.
mask (torch.Tensor, optional) – The mask tensor with shape (nnz, R, C), where nnz is the number of non-zero blocks. If every block is full, the mask tensor does not need to be provided.
packed_mask (torch.Tensor, optional) – The 1D packed mask tensor. If provided, mask will be ignored. The packed mask tensor is generated by flashinfer.quantization.packbits().
pos_encoding_mode (str, optional) – The position encoding applied inside attention kernels: NONE, ROPE_LLAMA (LLaMA-style rotary embedding) or ALIBI. Default is NONE.
allow_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).
logits_soft_cap (Optional[float]) – The attention logits soft-capping value (used in Gemini, Grok, Gemma-2, etc.). If not provided, it will be set to 0. If greater than 0, the logits will be capped according to the formula \(\texttt{logits\_soft\_cap} \times \tanh(x / \texttt{logits\_soft\_cap})\), where \(x\) is the input logits.
sm_scale (Optional[float]) – The scale used in softmax. If not provided, it will be set to 1.0 / sqrt(head_dim).
rope_scale (Optional[float]) – The scale used in RoPE interpolation. If not provided, it will be set to 1.0.
rope_theta (Optional[float]) – The theta used in RoPE. If not provided, it will be set to 1e4.
q_data_type (str, optional) – The data type of the query tensor.
kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, it will be set to q_data_type.
non_blocking (bool) – Whether to copy the input tensors to the device asynchronously; defaults to False. If True, the user should synchronize before calling run() or replaying a CUDA graph.
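The sm_scale default and the soft-capping formula above can be illustrated for a single query/key pair. This is a minimal sketch with a hypothetical helper, soft_capped_logit, and it is not the kernel's code. It assumes that the "input logits" x are the q·k dot product already multiplied by sm_scale. It treats a missing or non-positive logits_soft_cap as "no capping", following the documented default of 0.

```python
import math


# Hypothetical helper (not a flashinfer API): one attention logit for a single
# query/key pair, following the documented sm_scale default and soft-cap formula.
def soft_capped_logit(q_vec, k_vec, sm_scale=None, logits_soft_cap=None):
    head_dim = len(q_vec)
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)  # documented default for sm_scale
    # Assumption: the "input logits" x are the dot product scaled by sm_scale.
    x = sm_scale * sum(a * b for a, b in zip(q_vec, k_vec))
    # Capping applies only when logits_soft_cap is greater than 0.
    if logits_soft_cap is not None and logits_soft_cap > 0:
        x = logits_soft_cap * math.tanh(x / logits_soft_cap)
    return x
```

Because tanh saturates, a capped logit always stays strictly inside (-logits_soft_cap, logits_soft_cap). Small logits pass through almost unchanged.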
The plan() method should be called before any run() or run_return_lse() calls. It creates auxiliary data structures that are cached for multiple kernel runs.
num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.
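The shape constraints of plan() and the grouped-query-attention head mapping can be summarized as a few checks. The helpers check_plan_args and kv_head_for_qo_head below are hypothetical pre-flight sketches on plain Python lists; they are not flashinfer's own validation. The head mapping assumes the common GQA layout, in which each group of num_qo_heads // num_kv_heads consecutive query heads shares one KV head.

```python
# Hypothetical pre-flight checks (not flashinfer's own validation) for the
# constraints documented for plan(), using plain Python lists.
def check_plan_args(indptr, indices, M, N, R, C, num_qo_heads, num_kv_heads, mask_shape=None):
    MB = -(-M // R)  # ceil_div(M, R): the last block row may be partial
    if N % C != 0:
        raise ValueError("N should be divisible by C")
    NB = N // C
    if len(indptr) != MB + 1 or indptr[0] != 0:
        raise ValueError("indptr should have MB + 1 entries starting at 0")
    if any(a > b for a, b in zip(indptr, indptr[1:])):
        raise ValueError("indptr should be non-decreasing")
    nnz = indptr[-1]
    if len(indices) != nnz:
        raise ValueError("indices should have nnz = indptr[-1] entries")
    if any(not 0 <= j < NB for j in indices):
        raise ValueError("every block column index should satisfy 0 <= j < NB")
    if mask_shape is not None and tuple(mask_shape) != (nnz, R, C):
        raise ValueError("mask should have shape (nnz, R, C)")
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    return MB, NB, nnz


# Assumption: consecutive query heads share one KV head (the usual GQA layout).
def kv_head_for_qo_head(qo_head, num_qo_heads, num_kv_heads):
    group_size = num_qo_heads // num_kv_heads
    return qo_head // group_size
```

Under this layout, the example's 32 query heads and 8 KV heads give groups of 4. Query heads 0 to 3 read KV head 0, and query head 31 reads KV head 7.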
- reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor) → None
Reset the workspace buffer.
- Parameters:
float_workspace_buffer (torch.Tensor) – The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors.
int_workspace_buffer (torch.Tensor) – The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors.
- run(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return_lse: bool = False) → torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
Compute block-sparse attention between Q/K/V tensors.
- Parameters:
q (torch.Tensor) – The query tensor with shape (M, num_qo_heads, head_dim).
k (torch.Tensor) – The key tensor with shape (N, num_kv_heads, head_dim).
v (torch.Tensor) – The value tensor with shape (N, num_kv_heads, head_dim).
return_lse (bool) – Whether to return the logsumexp of the attention output.
- Returns:
If return_lse is False, the attention output, shape: [M, num_qo_heads, head_dim].
If return_lse is True, a tuple of two tensors: the attention output, shape: [M, num_qo_heads, head_dim], and the logsumexp of the attention output, shape: [M, num_qo_heads].
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
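To make the semantics of run() concrete, the pure-Python sketch below expands full BSR blocks into a dense boolean mask and then computes masked attention for one head. Both helper names are hypothetical. The sketch makes several assumptions:

- blocks are full (no per-block mask);
- sm_scale takes the documented default of 1/sqrt(head_dim);
- there is no soft cap or positional encoding;
- the logsumexp uses the natural log.

The kernel's LSE convention and its handling of rows with no unmasked keys are not specified above. The zero output and -inf LSE for such rows are therefore choices of this sketch.

```python
import math


# Hypothetical helper: expand full BSR blocks into an M x N boolean mask.
def bsr_to_dense_mask(indptr, indices, M, N, R, C):
    dense = [[False] * N for _ in range(M)]
    for bi in range(len(indptr) - 1):
        for p in range(indptr[bi], indptr[bi + 1]):
            bj = indices[p]
            # Clip the last block row at M, since MB = ceil_div(M, R).
            for r in range(bi * R, min((bi + 1) * R, M)):
                for c in range(bj * C, (bj + 1) * C):
                    dense[r][c] = True
    return dense


# Hypothetical reference: single-head masked attention with logsumexp.
# q: M x D, k and v: N x D, mask: M x N booleans (all plain nested lists).
def masked_attention_reference(q, k, v, mask, sm_scale=None):
    D = len(q[0])
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)
    out, lse = [], []
    for i, q_row in enumerate(q):
        cols = [j for j, keep in enumerate(mask[i]) if keep]
        if not cols:
            # This sketch's choice for a fully masked row.
            out.append([0.0] * D)
            lse.append(float("-inf"))
            continue
        logits = [sm_scale * sum(a * b for a, b in zip(q_row, k[j])) for j in cols]
        m = max(logits)  # subtract the row max for numerical stability
        weights = [math.exp(x - m) for x in logits]
        total = sum(weights)
        out.append([sum(w * v[j][d] for w, j in zip(weights, cols)) / total for d in range(D)])
        lse.append(m + math.log(total))  # natural-log logsumexp of the logits
    return out, lse
```

For the example's indptr [0, 1, 3, 5] and indices [2, 0, 2, 1, 2] with R = C = 1, bsr_to_dense_mask reproduces the dense mask [[0, 0, 1], [1, 0, 1], [0, 1, 1]] as booleans. That mask is what the doctest passes to single_prefill_with_kv_cache.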