flashinfer.xqa.xqa

flashinfer.xqa.xqa(q: Tensor, k_cache: Tensor, v_cache: Tensor, page_table: Tensor, seq_lens: Tensor, output: Tensor, workspace_buffer: Tensor, semaphores: Tensor, num_kv_heads: int, page_size: int, sinks: Tensor | None = None, q_scale: float = 1.0, kv_scale: Tensor | None = None, sliding_win_size: int = 0, sm_count: int | None = None) None

Apply attention with a paged KV cache using the XQA kernel.

Parameters:
  • q (torch.Tensor) – Query tensor with shape [batch_size, beam_width, num_q_heads, head_dim]. Data type should be torch.float16 or torch.bfloat16. Currently only beam_width 1 is supported.

  • k_cache (torch.Tensor) – Paged K cache tensor with shape [total_num_cache_heads, head_dim]. Data type should match the query tensor or be torch.float8_e4m3fn, in which case xqa runs the computation in FP8. Must have the same data type as v_cache.

  • v_cache (torch.Tensor) – Paged V cache tensor with shape [total_num_cache_heads, head_dim]. Data type should match the query tensor or be torch.float8_e4m3fn, in which case xqa runs the computation in FP8. Must have the same data type as k_cache.

  • page_table (torch.Tensor) – Page table tensor with shape [batch_size, nb_pages_per_seq]. Data type should be torch.int32. K and V share the same page table.

  • seq_lens (torch.Tensor) – Sequence lengths tensor with shape [batch_size, beam_width]. Data type should be torch.uint32.

  • output (torch.Tensor) – Output tensor with shape [batch_size, beam_width, num_q_heads, head_dim]. Data type should match query tensor. This tensor will be modified in-place.

  • workspace_buffer (torch.Tensor) – Workspace buffer for temporary computations. Data type should be torch.uint8.

  • semaphores (torch.Tensor) – Semaphore buffer for synchronization. Data type should be torch.uint32.

  • num_kv_heads (int) – Number of key-value heads in the attention mechanism.

  • page_size (int) – Size of each page in the paged KV cache. Must be one of [16, 32, 64, 128].

  • sinks (Optional[torch.Tensor], default=None) – Attention sink values with shape [num_kv_heads, head_group_ratio]. Data type should be torch.float32. If None, no attention sinks are used.

  • q_scale (float, default=1.0) – Scale factor for query tensor.

  • kv_scale (Optional[torch.Tensor], default=None) – Scale factor for KV cache with shape [1]. Data type should be torch.float32. If None, defaults to 1.0.

  • sliding_win_size (int, default=0) – Sliding window size for attention. If 0, no sliding window is used.

  • sm_count (Optional[int], default=None) – Number of streaming multiprocessors to use. If None, will be inferred from the device.
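
Putting the parameters together, the following is a minimal usage sketch. The flat cache layout (total_num_cache_heads = total_num_pages * page_size * num_kv_heads) and the workspace/semaphore buffer sizes are assumptions for illustration only; consult the FlashInfer documentation for the sizes and layout your build actually requires.

    import torch
    from flashinfer.xqa import xqa

    batch_size, beam_width = 2, 1        # only beam_width == 1 is supported
    num_q_heads, num_kv_heads = 32, 8    # head_group_ratio = 32 // 8 = 4
    head_dim, page_size = 128, 64
    nb_pages_per_seq = 4                 # max_seq_len = 4 * 64 = 256
    total_num_pages = batch_size * nb_pages_per_seq
    device = "cuda"

    q = torch.randn(batch_size, beam_width, num_q_heads, head_dim,
                    dtype=torch.bfloat16, device=device)
    output = torch.empty_like(q)  # written in-place by the kernel

    # Paged K/V caches; the flat layout below is an assumption.
    total_num_cache_heads = total_num_pages * page_size * num_kv_heads
    k_cache = torch.randn(total_num_cache_heads, head_dim,
                          dtype=torch.bfloat16, device=device)
    v_cache = torch.randn_like(k_cache)

    # One shared page table for K and V.
    page_table = torch.arange(batch_size * nb_pages_per_seq,
                              dtype=torch.int32, device=device)
    page_table = page_table.reshape(batch_size, nb_pages_per_seq)

    seq_lens = torch.full((batch_size, beam_width), 200,
                          dtype=torch.uint32, device=device)

    # Placeholder buffer sizes (not prescribed by this doc).
    workspace_buffer = torch.zeros(128 * 1024 * 1024,
                                   dtype=torch.uint8, device=device)
    semaphores = torch.zeros(8192, dtype=torch.uint32, device=device)

    xqa(q, k_cache, v_cache, page_table, seq_lens, output,
        workspace_buffer, semaphores,
        num_kv_heads=num_kv_heads, page_size=page_size)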

Note

The function automatically infers several parameters from tensor shapes:

  • batch_size from q.shape[0]
  • num_q_heads from q.shape[2]
  • head_dim from q.shape[-1]
  • input_dtype from q.dtype
  • kv_cache_dtype from k_cache.dtype
  • head_group_ratio from num_q_heads // num_kv_heads
  • max_seq_len from page_table.shape[-1] * page_size
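
For the hypothetical shapes used in the sketch above, these rules resolve as follows:

    # Inferred values for the sketch above (hypothetical shapes).
    assert q.shape[0] == 2                            # batch_size
    assert q.shape[2] == 32                           # num_q_heads
    assert q.shape[-1] == 128                         # head_dim
    assert q.dtype == torch.bfloat16                  # input_dtype
    assert k_cache.dtype == torch.bfloat16            # kv_cache_dtype
    assert num_q_heads // num_kv_heads == 4           # head_group_ratio
    assert page_table.shape[-1] * page_size == 256    # max_seq_len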