flashinfer.xqa.xqa_mla

flashinfer.xqa.xqa_mla(q: Tensor, k_cache: Tensor, v_cache: Tensor, page_table: Tensor, seq_lens: Tensor, output: Tensor, workspace_buffer: Tensor, semaphores: Tensor, page_size: int, q_scale: float | Tensor = 1.0, kv_scale: float | Tensor = 1.0, sm_count: int | None = None, enable_pdl: bool | None = None) → None

Apply attention with paged KV cache using the XQA MLA (Multi-Head Latent Attention) kernel.

Parameters:
  • q (torch.Tensor) – Query tensor with shape [batch_size, beam_width, num_q_heads, head_dim]. Data type should be torch.float8_e4m3fn. Currently only beam_width 1 is supported.

  • k_cache (torch.Tensor) – Paged K cache tensor with shape [total_num_cache_heads, head_dim]. Data type should be torch.float8_e4m3fn.

  • v_cache (torch.Tensor) – Paged V cache tensor with shape [total_num_cache_heads, head_dim]. Data type should be torch.float8_e4m3fn.

  • page_table (torch.Tensor) – Page table tensor with shape [batch_size, nb_pages_per_seq]. Data type should be torch.int32. K and V share the same table.

  • seq_lens (torch.Tensor) – Sequence lengths tensor with shape [batch_size, beam_width]. Data type should be torch.uint32.

  • output (torch.Tensor) – Output tensor with shape [batch_size, beam_width, num_q_heads, head_dim]. Data type should be torch.bfloat16. This tensor will be modified in-place.

  • workspace_buffer (torch.Tensor) – Workspace buffer for temporary computations. Data type should be torch.uint8.

  • semaphores (torch.Tensor) – Semaphore buffer for synchronization. Data type should be torch.uint32.

  • page_size (int) – Size of each page in the paged KV cache. Must be one of [16, 32, 64, 128].

  • q_scale (Union[float, torch.Tensor], default=1.0) – Scale factor for query tensor.

  • kv_scale (Union[float, torch.Tensor], default=1.0) – Scale factor for KV cache.

  • sm_count (Optional[int], default=None) – Number of streaming multiprocessors to use. If None, will be inferred from the device.

  • enable_pdl (Optional[bool], default=None) – Whether to enable PDL (Programmatic Dependent Launch) optimization. If None, it is enabled when the hardware supports it.
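
The shape and value constraints listed above can be checked before launching the kernel. The following is a minimal sketch in plain Python that works on shape tuples only; check_xqa_mla_shapes is a hypothetical helper, not part of FlashInfer, and the example dimensions are illustrative. It does not compare the last dimension of output with that of q, since the parameter descriptions do not confirm that the two must match.

```python
ALLOWED_PAGE_SIZES = (16, 32, 64, 128)  # taken from the page_size description


def check_xqa_mla_shapes(q_shape, page_table_shape, seq_lens_shape,
                         output_shape, page_size):
    """Hypothetical pre-flight check; raises ValueError on a violated constraint."""
    if len(q_shape) != 4:
        raise ValueError("q must be [batch_size, beam_width, num_q_heads, head_dim]")
    batch_size, beam_width, _num_q_heads, _head_dim = q_shape

    # The documentation states that only beam_width 1 is supported.
    if beam_width != 1:
        raise ValueError(f"beam_width must be 1, got {beam_width}")

    if page_size not in ALLOWED_PAGE_SIZES:
        raise ValueError(f"page_size must be one of {ALLOWED_PAGE_SIZES}, got {page_size}")

    # page_table is [batch_size, nb_pages_per_seq].
    if len(page_table_shape) != 2 or page_table_shape[0] != batch_size:
        raise ValueError("page_table must be [batch_size, nb_pages_per_seq]")

    # seq_lens is [batch_size, beam_width].
    if tuple(seq_lens_shape) != (batch_size, beam_width):
        raise ValueError("seq_lens must be [batch_size, beam_width]")

    # Only the leading three dimensions of output are compared with q.
    if len(output_shape) != 4 or tuple(output_shape[:3]) != tuple(q_shape[:3]):
        raise ValueError("output must share batch_size, beam_width, num_q_heads with q")


# Illustrative shapes only; the real dimensions depend on the model.
check_xqa_mla_shapes(
    q_shape=(2, 1, 128, 576),
    page_table_shape=(2, 8),
    seq_lens_shape=(2, 1),
    output_shape=(2, 1, 128, 512),
    page_size=64,
)
```

In real code the tuples would come from tensor.shape, and a similar check could cover the documented data types.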

Note

The function automatically infers several parameters from tensor shapes:

  • batch_size from q.shape[0]

  • head_dim from q.shape[-1]

  • head_group_ratio is fixed to 128 for MLA

  • max_seq_len from page_table.shape[-1] * page_size
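
The inference rules in this note can be reproduced with a few lines of plain Python, which may help when sizing the page table. This is a sketch only: infer_xqa_mla_params is a hypothetical name, not a FlashInfer function, and the shapes used are illustrative.

```python
def infer_xqa_mla_params(q_shape, page_table_shape, page_size):
    """Mirror the documented inference rules on plain shape tuples."""
    return {
        "batch_size": q_shape[0],
        "head_dim": q_shape[-1],
        "head_group_ratio": 128,  # documented as fixed for MLA
        "max_seq_len": page_table_shape[-1] * page_size,
    }


# Illustrative shapes: 4 sequences, 32 pages per sequence, 64 tokens per page.
params = infer_xqa_mla_params(
    q_shape=(4, 1, 128, 576),
    page_table_shape=(4, 32),
    page_size=64,
)
print(params["max_seq_len"])  # 32 * 64 = 2048
```

Because max_seq_len is derived from the page table width, a sequence longer than nb_pages_per_seq * page_size tokens cannot be addressed without a wider page table.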