flashinfer.xqa.xqa
- flashinfer.xqa.xqa(q: Tensor, k_cache: Tensor, v_cache: Tensor, page_table: Tensor, seq_lens: Tensor, output: Tensor, workspace_buffer: Tensor, semaphores: Tensor, num_kv_heads: int, page_size: int, sinks: Tensor | None = None, q_scale: float = 1.0, kv_scale: Tensor | None = None, sliding_win_size: int = 0, sm_count: int | None = None) → None
Apply attention with paged KV cache using the XQA kernel.
- Parameters:
  - q (torch.Tensor) – Query tensor with shape [batch_size, beam_width, num_q_heads, head_dim]. Data type should be torch.float16 or torch.bfloat16. Currently only beam_width = 1 is supported.
  - k_cache (torch.Tensor) – Paged K cache tensor with shape [total_num_cache_heads, head_dim]. Data type should match the query tensor, or be torch.float8_e4m3fn, in which case xqa runs fp8 computation. Must be the same data type as v_cache.
  - v_cache (torch.Tensor) – Paged V cache tensor with shape [total_num_cache_heads, head_dim]. Data type should match the query tensor, or be torch.float8_e4m3fn, in which case xqa runs fp8 computation. Must be the same data type as k_cache.
  - page_table (torch.Tensor) – Page table tensor with shape [batch_size, nb_pages_per_seq]. Data type should be torch.int32. K and V share the same table.
  - seq_lens (torch.Tensor) – Sequence lengths tensor with shape [batch_size, beam_width]. Data type should be torch.uint32.
  - output (torch.Tensor) – Output tensor with shape [batch_size, beam_width, num_q_heads, head_dim]. Data type should match the query tensor. This tensor is modified in place.
  - workspace_buffer (torch.Tensor) – Workspace buffer for temporary computations. Data type should be torch.uint8.
  - semaphores (torch.Tensor) – Semaphore buffer for synchronization. Data type should be torch.uint32.
  - num_kv_heads (int) – Number of key-value heads in the attention mechanism.
  - page_size (int) – Size of each page in the paged KV cache. Must be one of [16, 32, 64, 128].
  - sinks (Optional[torch.Tensor], default=None) – Attention sink values with shape [num_kv_heads, head_group_ratio]. Data type should be torch.float32. If None, no attention sinks are used.
  - q_scale (float, default=1.0) – Scale factor for the query tensor.
  - kv_scale (Optional[torch.Tensor], default=None) – Scale factor for the KV cache with shape [1]. Data type should be torch.float32. If None, defaults to 1.0.
  - sliding_win_size (int, default=0) – Sliding window size for attention. If 0, no sliding window is used.
  - sm_count (Optional[int], default=None) – Number of streaming multiprocessors to use. If None, it is inferred from the device.
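A minimal usage sketch follows. The sizes chosen for workspace_buffer and semaphores are placeholders (the documentation above does not specify required sizes), and the assumption that each page stores page_size × num_kv_heads cache heads is an inference from the [total_num_cache_heads, head_dim] layout, not something stated here; consult the FlashInfer source for the exact sizing rules.

```python
import torch
from flashinfer.xqa import xqa

batch_size, beam_width = 2, 1              # only beam_width = 1 is supported
num_q_heads, num_kv_heads = 32, 8          # head_group_ratio = 4
head_dim, page_size = 128, 32
pages_per_seq = 4                          # max_seq_len = 4 * 32 = 128
total_pages = batch_size * pages_per_seq
device, dtype = "cuda", torch.float16

q = torch.randn(batch_size, beam_width, num_q_heads, head_dim,
                dtype=dtype, device=device)
output = torch.empty_like(q)               # written in place by the kernel

# Assumed layout: each page stores page_size tokens x num_kv_heads heads.
total_num_cache_heads = total_pages * page_size * num_kv_heads
k_cache = torch.randn(total_num_cache_heads, head_dim, dtype=dtype, device=device)
v_cache = torch.randn(total_num_cache_heads, head_dim, dtype=dtype, device=device)

# Trivial page table: sequence i owns pages [i*pages_per_seq, (i+1)*pages_per_seq).
page_table = torch.arange(total_pages, dtype=torch.int32,
                          device=device).reshape(batch_size, pages_per_seq)
seq_lens = torch.full((batch_size, beam_width), 100,
                      dtype=torch.uint32, device=device)

# Placeholder sizes; the required sizes depend on the kernel configuration.
workspace_buffer = torch.zeros(16 * 1024 * 1024, dtype=torch.uint8, device=device)
semaphores = torch.zeros(8192, dtype=torch.uint32, device=device)

xqa(q, k_cache, v_cache, page_table, seq_lens, output,
    workspace_buffer, semaphores,
    num_kv_heads=num_kv_heads, page_size=page_size)
```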
Note
The function automatically infers several parameters from tensor shapes:
- batch_size from q.shape[0]
- num_q_heads from q.shape[2]
- head_dim from q.shape[-1]
- input_dtype from q.dtype
- kv_cache_dtype from k_cache.dtype
- head_group_ratio from num_q_heads // num_kv_heads
- max_seq_len from page_table.shape[-1] * page_size
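For concreteness, here is how those inferred values fall out of the tensors in the sketch above; the numbers in the comments reflect the example shapes, not anything mandated by the kernel.

```python
# Values the kernel derives on its own; none of these are passed explicitly.
batch_size = q.shape[0]                         # 2
num_q_heads = q.shape[2]                        # 32
head_dim = q.shape[-1]                          # 128
input_dtype = q.dtype                           # torch.float16
kv_cache_dtype = k_cache.dtype                  # torch.float16
head_group_ratio = num_q_heads // num_kv_heads  # 32 // 8 = 4
max_seq_len = page_table.shape[-1] * page_size  # 4 * 32 = 128
```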