flashinfer.sparse#

Kernels for block-sparse FlashAttention.

class flashinfer.sparse.BlockSparseAttentionWrapper(float_workspace_buffer: torch.Tensor)#

Wrapper class for attention computation with a block-sparse matrix as the attention mask. The definition of the block-sparse matrix format can be found in bsr_matrix in SciPy.

This API supports any block size (R, C).

Example

>>> import torch
>>> import flashinfer
>>> num_qo_heads = 32
>>> num_kv_heads = 8
>>> head_dim = 128
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(workspace_buffer)
>>> # sparse mask: [[0, 0, 1], [1, 0, 1], [0, 1, 1]]
>>> M = 3
>>> N = 3
>>> indptr = torch.tensor([0, 1, 3, 5], dtype=torch.int32, device="cuda:0")
>>> indices = torch.tensor([2, 0, 2, 1, 2], dtype=torch.int32, device="cuda:0")
>>> bsr_wrapper.plan(
...     indptr,
...     indices,
...     M,
...     N,
...     1, # R: number of rows in each block
...     1, # C: number of columns in each block
...     num_qo_heads,
...     num_kv_heads,
...     head_dim,
... )
>>> q = torch.randn((M, num_qo_heads, head_dim), dtype=torch.float16, device="cuda:0")
>>> k = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device="cuda:0")
>>> v = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device="cuda:0")
>>> o = bsr_wrapper.run(q, k, v)
>>> # use dense implementation with attention mask for comparison
>>> mask = torch.tensor([[0, 0, 1], [1, 0, 1], [0, 1, 1]], dtype=torch.bool, device="cuda:0")
>>> o_ref = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask)
>>> torch.allclose(o, o_ref)
True
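
The indptr and indices arrays above follow SciPy's BSR layout. As a sketch (not part of the flashinfer API; assumes NumPy and SciPy are available), they can be derived from a dense block mask:

>>> import numpy as np
>>> import scipy.sparse
>>> dense_mask = np.array([[0, 0, 1], [1, 0, 1], [0, 1, 1]])
>>> bsr = scipy.sparse.bsr_matrix(dense_mask, blocksize=(1, 1))  # block size (R, C) = (1, 1)
>>> indptr = torch.from_numpy(bsr.indptr).to(dtype=torch.int32, device="cuda:0")    # [0, 1, 3, 5]
>>> indices = torch.from_numpy(bsr.indices).to(dtype=torch.int32, device="cuda:0")  # [2, 0, 2, 1, 2]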
__init__(float_workspace_buffer: torch.Tensor) → None#

Constructor of BlockSparseAttentionWrapper.

Parameters:

float_workspace_buffer (torch.Tensor) – The user-reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB; the device of the workspace buffer should be the same as the device of the input tensors.

plan(indptr: torch.Tensor, indices: torch.Tensor, M: int, N: int, R: int, C: int, num_qo_heads: int, num_kv_heads: int, head_dim: int, mask: torch.Tensor | None = None, packed_mask: torch.Tensor | None = None, pos_encoding_mode: str = 'NONE', allow_fp16_qk_reduction: bool = False, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, q_data_type: str | torch.dtype = 'float16', kv_data_type: str | torch.dtype | None = None, non_blocking: bool = False) → None#

Create auxiliary data structures for block sparse attention.

Parameters:
  • indptr (torch.Tensor) – The block index pointer of the block-sparse matrix on the row dimension, shape (MB + 1,), where MB is the number of blocks in the row dimension.

  • indices (torch.Tensor) – The block indices of the block-sparse matrix on the column dimension, shape (nnz,), where nnz is the number of non-zero blocks. The elements in the indices array should be less than NB, the number of blocks in the column dimension.

  • M (int) – The number of rows of the block-sparse matrix, MB = ceil_div(M, R).

  • N (int) – The number of columns of the block-sparse matrix, NB = N // C; N should be divisible by C.

  • R (int) – The number of rows in each block.

  • C (int) – The number of columns in each block.

  • num_qo_heads (int) – The number of heads in the query/output tensor.

  • num_kv_heads (int) – The number of heads in the key/value tensor.

  • head_dim (int) – The dimension of each head.

  • mask (torch.Tensor, optional) – The mask tensor with shape (nnz, R, C), where nnz is the number of non-zero blocks. If every non-zero block is fully attended, the mask tensor does not need to be provided.

  • packed_mask (torch.Tensor, optional) – The 1D packed mask tensor; if provided, mask will be ignored. The packed mask tensor is generated by flashinfer.quantization.packbits().

  • pos_encoding_mode (str, optional) – The position encoding applied inside attention kernels; can be NONE, ROPE_LLAMA (LLaMA-style rotary embedding), or ALIBI. Default is NONE.

  • allow_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).

  • logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to 0. If greater than 0, the logits will be capped according to formula: \(\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})\), where \(x\) is the input logits.

  • sm_scale (Optional[float]) – The scale used in softmax; if not provided, it will be set to 1.0 / sqrt(head_dim).

  • rope_scale (Optional[float]) – The scale used in RoPE interpolation; if not provided, it will be set to 1.0.

  • rope_theta (Optional[float]) – The theta used in RoPE; if not provided, it will be set to 1e4.

  • q_data_type (Union[str, torch.dtype], optional) – The data type of the query tensor, defaults to float16.

  • kv_data_type (Optional[Union[str, torch.dtype]]) – The data type of the key/value tensor. If None, will be set to q_data_type.

  • non_blocking (bool) – Whether to copy the input tensors to the device asynchronously, defaults to False. If True, the user should synchronize the device before calling run() or replaying the CUDA graph.

The plan() method should be called before any run() or run_return_lse() calls; auxiliary data structures will be created during this call and cached for multiple kernel runs.

The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.
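
As an illustrative sketch (block size, sparsity pattern, and per-block mask chosen here purely for demonstration, reusing the wrapper and head configuration from the class-level example above), planning a block-diagonal layout with 16 × 16 blocks and a causal mask inside each block could look like:

>>> MB, NB = 4, 4                  # 4 x 4 grid of blocks
>>> R, C = 16, 16                  # block size
>>> M, N = MB * R, NB * C
>>> indptr = torch.arange(MB + 1, dtype=torch.int32, device="cuda:0")  # one block per block-row
>>> indices = torch.arange(NB, dtype=torch.int32, device="cuda:0")     # keep only the diagonal blocks
>>> mask = torch.tril(torch.ones(indices.numel(), R, C, device="cuda:0")).bool()  # causal inside each block
>>> bsr_wrapper.plan(indptr, indices, M, N, R, C,
...                  num_qo_heads, num_kv_heads, head_dim, mask=mask)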

reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor) → None#

Reset the workspace buffer.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The new float workspace buffer; its device should be the same as the device of the input tensors.

  • int_workspace_buffer (torch.Tensor) – The new int workspace buffer; its device should be the same as the device of the input tensors.
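
A minimal usage sketch (the buffer sizes below are illustrative, not recommendations from the docstring):

>>> new_float_workspace = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> new_int_workspace = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> bsr_wrapper.reset_workspace_buffer(new_float_workspace, new_int_workspace)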

run(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return_lse: bool = False) → torch.Tensor | Tuple[torch.Tensor, torch.Tensor]#

Compute block-sparse attention between Q/K/V tensors.

Parameters:
  • q (torch.Tensor) – The query tensor with shape (M, num_qo_heads, head_dim).

  • k (torch.Tensor) – The key tensor with shape (N, num_kv_heads, head_dim).

  • v (torch.Tensor) – The value tensor with shape (N, num_kv_heads, head_dim).

  • return_lse (bool) – Whether to return the logsumexp of the attention output.

Returns:

If return_lse is False, the attention output, shape: [M, num_qo_heads, head_dim]. If return_lse is True, a tuple of two tensors:

  • The attention output, shape: [M, num_qo_heads, head_dim].

  • The logsumexp of attention output, shape: [M, num_qo_heads].

Return type:

Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
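
For example, reusing the tensors from the class-level example above (M = 3, num_qo_heads = 32, head_dim = 128), the logsumexp can be returned alongside the output:

>>> o, lse = bsr_wrapper.run(q, k, v, return_lse=True)
>>> o.shape
torch.Size([3, 32, 128])
>>> lse.shape
torch.Size([3, 32])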