flashinfer.sparse

Kernels for block-sparse FlashAttention.

class flashinfer.sparse.BlockSparseAttentionWrapper(float_workspace_buffer: torch.Tensor)

Wrapper class for attention computation with a block-sparse matrix as the attention mask. The block-sparse matrix follows the block sparse row (BSR) format defined by bsr_matrix in SciPy.

This API supports any block size (R, C).
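
The indptr/indices pair uses the CSR convention applied to blocks: the non-zero blocks of block row i have their block-column indices stored in indices[indptr[i]:indptr[i + 1]]. The following minimal pure-Python sketch expands such a pattern into a dense 0/1 mask so the encoding can be checked by eye. The helper name bsr_block_mask_to_dense is hypothetical and is not part of FlashInfer. The dense result has MB * R rows, which can exceed M when M is not a multiple of R.

```python
def bsr_block_mask_to_dense(indptr, indices, MB, NB, R, C):
    """Expand a BSR sparsity pattern into a dense (MB * R) x (NB * C) 0/1 mask.

    Hypothetical helper for illustration only; not a FlashInfer API.
    """
    dense = [[0] * (NB * C) for _ in range(MB * R)]
    for i in range(MB):                               # block row i
        for p in range(indptr[i], indptr[i + 1]):     # non-zero blocks of block row i
            j = indices[p]                            # their block column
            assert 0 <= j < NB, "indices must be less than NB"
            for r in range(R):                        # fill the whole R x C block
                for c in range(C):
                    dense[i * R + r][j * C + c] = 1
    return dense


# 1x1 blocks: the 3x3 pattern [[0, 0, 1], [1, 0, 1], [0, 1, 1]]
print(bsr_block_mask_to_dense([0, 1, 3, 5], [2, 0, 2, 1, 2], MB=3, NB=3, R=1, C=1))
# [[0, 0, 1], [1, 0, 1], [0, 1, 1]]

# 2x2 blocks: one block row whose only non-zero block is in block column 1
print(bsr_block_mask_to_dense([0, 1], [1], MB=1, NB=2, R=2, C=2))
# [[0, 0, 1, 1], [0, 0, 1, 1]]
```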

Example

>>> import torch
>>> import flashinfer
>>> num_qo_heads = 32
>>> num_kv_heads = 8
>>> head_dim = 128
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(workspace_buffer)
>>> # sparse mask: [[0, 0, 1], [1, 0, 1], [0, 1, 1]]
>>> M = 3
>>> N = 3
>>> indptr = torch.tensor([0, 1, 3, 5], dtype=torch.int32, device="cuda:0")
>>> indices = torch.tensor([2, 0, 2, 1, 2], dtype=torch.int32, device="cuda:0")
>>> bsr_wrapper.plan(
...     indptr,
...     indices,
...     M,
...     N,
...     1, # R(block_rows)=1
...     1, # C(block_columns)=1
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
... )
>>> q = torch.randn((M, num_qo_heads, head_dim), dtype=torch.float16, device="cuda:0")
>>> k = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device="cuda:0")
>>> v = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device="cuda:0")
>>> o = bsr_wrapper.run(q, k, v)
>>> # use dense implementation with attention mask for comparison
>>> mask = torch.tensor([[0, 0, 1], [1, 0, 1], [0, 1, 1]], dtype=torch.bool, device="cuda:0")
>>> o_ref = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask)
>>> torch.allclose(o, o_ref)
True

__init__(float_workspace_buffer: torch.Tensor) → None

Constructor of BlockSparseAttentionWrapper.

Parameters:

float_workspace_buffer (torch.Tensor) – The user-reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB. The workspace buffer must be on the same device as the input tensors.

plan(indptr: torch.Tensor, indices: torch.Tensor, M: int, N: int, R: int, C: int, num_qo_heads: int, num_kv_heads: int, head_dim: int, mask: torch.Tensor | None = None, packed_mask: torch.Tensor | None = None, pos_encoding_mode: str = 'NONE', allow_fp16_qk_reduction: bool = False, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, q_data_type: str = 'float16') → None

Create auxiliary data structures for block sparse attention.

Parameters:
  • indptr (torch.Tensor) – The block index pointer of the block-sparse matrix along the row dimension, shape (MB + 1,), where MB is the number of blocks in the row dimension.

  • indices (torch.Tensor) – The block-column indices of the non-zero blocks, shape (nnz,), where nnz is the number of non-zero blocks. Every element of indices must be less than NB, the number of blocks in the column dimension.

  • M (int) – The number of rows of the block-sparse matrix; MB = ceil_div(M, R).

  • N (int) – The number of columns of the block-sparse matrix; NB = N // C, so N must be divisible by C.

  • R (int) – The number of rows in each block.

  • C (int) – The number of columns in each block.

  • num_qo_heads (int) – The number of heads in the query/output tensor.

  • num_kv_heads (int) – The number of heads in the key/value tensor.

  • head_dim (int) – The dimension of each head.

  • mask (torch.Tensor, optional) – The mask tensor with shape (nnz, R, C), where nnz is the number of non-zero blocks. If every block is full, the mask tensor does not need to be provided.

  • packed_mask (torch.Tensor, optional) – The 1D packed mask tensor. If provided, mask is ignored. The packed mask tensor is generated by flashinfer.quantization.packbits().

  • pos_encoding_mode (str, optional) – The position encoding applied inside attention kernels: one of NONE, ROPE_LLAMA (LLaMA-style rotary embedding) or ALIBI. Default is NONE.

  • allow_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).

  • logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.). If not provided, it is set to 0. If greater than 0, the logits are capped according to the formula \(\texttt{logits\_soft\_cap} \times \mathrm{tanh}(x / \texttt{logits\_soft\_cap})\), where \(x\) is the input logits.

  • sm_scale (Optional[float]) – The scale used in softmax. If not provided, it is set to 1.0 / sqrt(head_dim).

  • rope_scale (Optional[float]) – The scale used in RoPE interpolation. If not provided, it is set to 1.0.

  • rope_theta (Optional[float]) – The theta used in RoPE. If not provided, it is set to 1e4.

  • q_data_type (str, optional) – The data type of the query tensor. Default is 'float16'.
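
The indptr, indices and mask arguments are related through MB = ceil_div(M, R) and NB = N // C. A minimal pure-Python sketch of deriving all three from a dense M x N boolean mask is shown here; a block counts as non-zero if any element inside it is set. The helper name dense_mask_to_bsr is hypothetical, and padding the rows past M in the last block row with False is an assumption about how partial blocks should be filled. The resulting lists would still need to be turned into int32/bool tensors on the target device before calling plan().

```python
def dense_mask_to_bsr(dense, R, C):
    """Convert a dense M x N mask into (indptr, indices, block_masks).

    Hypothetical helper for illustration only; not a FlashInfer API.
    block_masks has shape (nnz, R, C). Rows past M in the last block row
    are padded with False (an assumption about partial blocks).
    """
    M, N = len(dense), len(dense[0])
    if N % C != 0:
        raise ValueError("N must be divisible by C")
    MB = -(-M // R)                  # ceil_div(M, R)
    NB = N // C
    indptr, indices, block_masks = [0], [], []
    for i in range(MB):
        for j in range(NB):
            block = [
                [bool(dense[i * R + r][j * C + c]) if i * R + r < M else False
                 for c in range(C)]
                for r in range(R)
            ]
            if any(any(row) for row in block):   # keep blocks with at least one True
                indices.append(j)                # j < NB holds by construction
                block_masks.append(block)
        indptr.append(len(indices))              # indptr ends up with MB + 1 entries
    return indptr, indices, block_masks


dense = [[1, 1, 0, 0],
         [0, 1, 0, 0],
         [0, 0, 1, 1]]                           # M = 3, N = 4
indptr, indices, block_masks = dense_mask_to_bsr(dense, R=2, C=2)
print(indptr)    # [0, 1, 2]
print(indices)   # [0, 1]
# The mask can only be omitted when every non-zero block is completely full.
needs_mask = not all(all(all(row) for row in b) for b in block_masks)
print(needs_mask)  # True
```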

The plan() method must be called before any run() or run_return_lse() call. Auxiliary data structures are created during this call and cached for reuse across multiple kernel runs.

The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.
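
Under grouped query attention, each KV head is shared by num_qo_heads // num_kv_heads query heads. A minimal sketch of the head mapping, assuming the common convention that consecutive query heads form one group; the helper name kv_head_for_qo_head is hypothetical.

```python
def kv_head_for_qo_head(qo_head, num_qo_heads, num_kv_heads):
    """Return the KV head used by a query/output head under GQA.

    Hypothetical helper; assumes consecutive query heads share a KV head.
    """
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    group_size = num_qo_heads // num_kv_heads   # query heads per KV head
    return qo_head // group_size


# 32 query heads and 8 KV heads: each KV head serves 4 query heads.
print([kv_head_for_qo_head(h, 32, 8) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

When num_qo_heads equals num_kv_heads the group size is 1 and every query head maps to its own KV head, which is ordinary multi-head attention.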

reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer) → None

Reset the workspace buffer.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors.

  • int_workspace_buffer (torch.Tensor) – The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors.

run(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) → torch.Tensor

Compute block-sparse attention between Q/K/V tensors.

Parameters:
  • q (torch.Tensor) – The query tensor with shape (M, num_qo_heads, head_dim).

  • k (torch.Tensor) – The key tensor with shape (N, num_kv_heads, head_dim).

  • v (torch.Tensor) – The value tensor with shape (N, num_kv_heads, head_dim).

Returns:

The attention output, with shape (M, num_qo_heads, head_dim).

Return type:

torch.Tensor
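
For a single query head, the result of run() can be described as ordinary masked softmax attention over the dense expansion of the block-sparse mask. The following is a slow pure-Python reference sketch, not FlashInfer's kernel. The function name reference_masked_attention is hypothetical; it assumes the logits soft cap is applied after scaling by sm_scale, and it assumes rows with no allowed keys produce zeros. Running it once per query head, paired with the KV head of that head's group, would give an output with shape (M, num_qo_heads, head_dim).

```python
import math


def reference_masked_attention(q, k, v, mask, sm_scale=None, logits_soft_cap=None):
    """Single-head masked attention on nested lists (hypothetical reference).

    q: M x d, k: N x d, v: N x d_v, mask: M x N (truthy = attend).
    Returns an M x d_v list of lists.
    """
    d = len(q[0])
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(d)             # documented default
    out = []
    for i, qi in enumerate(q):
        logits = []
        for j, kj in enumerate(k):
            if not mask[i][j]:
                continue                          # masked-out key: skip it
            x = sm_scale * sum(a * b for a, b in zip(qi, kj))
            if logits_soft_cap is not None and logits_soft_cap > 0:
                # cap * tanh(x / cap); assumed to act on the scaled logits
                x = logits_soft_cap * math.tanh(x / logits_soft_cap)
            logits.append((j, x))
        if not logits:                            # assumption: empty rows give zeros
            out.append([0.0] * len(v[0]))
            continue
        m = max(x for _, x in logits)             # subtract max for numerical stability
        weights = [(j, math.exp(x - m)) for j, x in logits]
        total = sum(w for _, w in weights)
        out.append([sum(w * v[j][c] for j, w in weights) / total
                    for c in range(len(v[0]))])
    return out


k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(reference_masked_attention([[0.0, 0.0]], k, v, [[1, 1]]))  # [[2.0, 3.0]]
```

Computing the softmax only over the unmasked keys is equivalent to setting masked logits to negative infinity, which is how a dense reference with a boolean attention mask behaves.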