flashinfer.xqa.xqa_mla
- flashinfer.xqa.xqa_mla(q: Tensor, k_cache: Tensor, v_cache: Tensor, page_table: Tensor, seq_lens: Tensor, output: Tensor, workspace_buffer: Tensor, semaphores: Tensor, page_size: int, q_scale: float | Tensor = 1.0, kv_scale: float | Tensor = 1.0, sm_count: int | None = None, enable_pdl: bool | None = None) → None
Apply attention with paged KV cache using the XQA MLA (Multi-Head Latent Attention) kernel.
- Parameters:
q (torch.Tensor) – Query tensor with shape [batch_size, beam_width, num_q_heads, head_dim]. Data type should be torch.float8_e4m3fn. Currently only beam_width 1 is supported.
k_cache (torch.Tensor) – Paged K cache tensor with shape [total_num_cache_heads, head_dim]. Data type should be torch.float8_e4m3fn.
v_cache (torch.Tensor) – Paged V cache tensor with shape [total_num_cache_heads, head_dim]. Data type should be torch.float8_e4m3fn.
page_table (torch.Tensor) – Page table tensor with shape [batch_size, nb_pages_per_seq]. Data type should be torch.int32. K and V share the same table.
seq_lens (torch.Tensor) – Sequence lengths tensor with shape [batch_size, beam_width]. Data type should be torch.uint32.
output (torch.Tensor) – Output tensor with shape [batch_size, beam_width, num_q_heads, head_dim]. Data type should be torch.bfloat16. This tensor will be modified in-place.
workspace_buffer (torch.Tensor) – Workspace buffer for temporary computations. Data type should be torch.uint8.
semaphores (torch.Tensor) – Semaphore buffer for synchronization. Data type should be torch.uint32.
page_size (int) – Size of each page in the paged KV cache. Must be one of [16, 32, 64, 128].
q_scale (Union[float, torch.Tensor], default=1.0) – Scale factor for query tensor.
kv_scale (Union[float, torch.Tensor], default=1.0) – Scale factor for KV cache.
sm_count (Optional[int], default=None) – Number of streaming multiprocessors to use. If None, will be inferred from the device.
enable_pdl (Optional[bool], default=None) – Whether to enable PDL (Programmatic Dependent Launch) optimization. If None, will be set to True if hardware supports it.
Note
The function automatically infers several parameters from tensor shapes:
- batch_size from q.shape[0]
- head_dim from q.shape[-1]
- head_group_ratio is fixed to 128 for MLA
- max_seq_len from page_table.shape[-1] * page_size
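A minimal usage sketch follows. The concrete sizes here (head_dim, num_q_heads, workspace and semaphore buffer sizes, and the cache layout) are illustrative assumptions, not values mandated by the API; consult the flashinfer source for the kernel's exact buffer requirements.

```python
import torch
import flashinfer

device = "cuda"
batch_size, beam_width = 2, 1      # only beam_width == 1 is supported
num_q_heads = 128                  # head_group_ratio is fixed to 128 for MLA
head_dim = 576                     # assumption: a typical MLA latent + RoPE dim
page_size, nb_pages_per_seq = 32, 4
total_num_cache_heads = batch_size * nb_pages_per_seq * page_size  # assumption

# FP8 query; PyTorch cannot sample float8 directly, so cast from bfloat16.
q = torch.randn(batch_size, beam_width, num_q_heads, head_dim,
                device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
k_cache = torch.randn(total_num_cache_heads, head_dim,
                      device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
v_cache = torch.randn(total_num_cache_heads, head_dim,
                      device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)

# One shared page table for K and V; pages assigned contiguously here.
page_table = torch.arange(batch_size * nb_pages_per_seq, device=device,
                          dtype=torch.int32).reshape(batch_size, nb_pages_per_seq)
seq_lens = torch.full((batch_size, beam_width), page_size * nb_pages_per_seq,
                      device=device, dtype=torch.uint32)

# Output is written in-place in bfloat16.
output = torch.empty(batch_size, beam_width, num_q_heads, head_dim,
                     device=device, dtype=torch.bfloat16)

# Buffer sizes below are assumptions; real requirements may differ.
workspace_buffer = torch.zeros(128 * 1024 * 1024, device=device, dtype=torch.uint8)
semaphores = torch.zeros(8192, device=device, dtype=torch.uint32)

flashinfer.xqa.xqa_mla(q, k_cache, v_cache, page_table, seq_lens,
                       output, workspace_buffer, semaphores, page_size=page_size)
# `output` now holds the attention result.
```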