flashinfer.decode.xqa_batch_decode_with_kv_cache
- flashinfer.decode.xqa_batch_decode_with_kv_cache(query: Tensor, kv_cache: Tensor | Tuple[Tensor, Tensor], workspace_buffer: Tensor, block_tables: Tensor, seq_lens: Tensor, max_seq_len: int, bmm1_scale: float | Tensor = 1.0, bmm2_scale: float | Tensor = 1.0, window_left: int = -1, out: Tensor | None = None, sinks: Tensor | None = None, kv_layout: str = 'NHD', enable_pdl: bool = None, q_len_per_req: int | None = 1, o_scale: float | None = 1.0, mask: Tensor | None = None, kv_cache_sf: Tensor | Tuple[Tensor, Tensor] | None = None) → Tensor
- Parameters:
query (torch.Tensor) – Query tensor with shape [num_tokens, num_heads, head_dim], where num_tokens = batch_size * q_len_per_req.
kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – Paged KV cache. If kv_cache is a single tensor, its shape is [num_pages, 1 or 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD, or [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] if kv_layout is HND. If kv_cache is a tuple of two tensors, each has shape [num_pages, page_size, num_kv_heads, head_dim] if kv_layout is NHD, or [num_pages, num_kv_heads, page_size, head_dim] if kv_layout is HND.
workspace_buffer (torch.Tensor) – Workspace buffer. Must be initialized to 0 before its first use.
block_tables (torch.Tensor) – Page table of the KV cache, with shape [batch_size, num_pages].
seq_lens (torch.Tensor) – A uint32 1D tensor holding the KV sequence length of each request, with shape [batch_size].
max_seq_len (int) – Maximum sequence length of the KV cache.
bmm1_scale (Union[float, torch.Tensor]) – Fused scale for the bmm1 (Q·Kᵀ) input.
bmm2_scale (Union[float, torch.Tensor]) – Fused scale for the bmm2 (P·V) input.
window_left (int = -1) – The left (inclusive) window size for the attention window. When set to -1, the window covers the full length of the sequence. Defaults to -1.
out (Optional[torch.Tensor] = None) – Output tensor. If not provided, it is allocated with query.dtype.
sinks (Optional[torch.Tensor] = None) – Additional per-head value added to the denominator of the softmax.
kv_layout (str) – The layout of the KV cache. Can be either NHD or HND. Defaults to NHD.
enable_pdl (bool) – Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Only supported on sm90 and newer, and currently only for FA2, CUDA core, and trtllm-gen decode.
q_len_per_req (Optional[int] = 1) – Number of query tokens per request (i.e. the speculative-decoding / MTP depth). query is expected to have batch_size * q_len_per_req rows along its leading dimension. Defaults to 1.
o_scale (Optional[float] = 1.0) – Output scale factor for FP8 output.
mask (Optional[torch.Tensor] = None) – causal attention mask for xqa speculative decoding.
kv_cache_sf (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None) – KV cache scaling factors. Must be provided when an NVFP4 KV cache is used.
- Returns:
out – output torch.Tensor.
- Return type:
torch.Tensor
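The kv_cache shape rules in the parameter list can be captured in a small helper. This is a hypothetical sketch (expected_kv_cache_shapes is not a FlashInfer function) that only reproduces the documented shapes for each kv_layout and for the single-tensor versus tuple forms; the meaning of the "1 or 2" dimension is not spelled out here, so the sketch simply accepts either value.

```python
def expected_kv_cache_shapes(num_pages, page_size, num_kv_heads, head_dim,
                             kv_layout="NHD", as_tuple=False, kv_dim=2):
    """Hypothetical helper: shape(s) a kv_cache argument should have.

    Mirrors the documented rules only; it does not inspect real tensors.
    """
    # Per-page block layout: the only part that depends on kv_layout.
    if kv_layout == "NHD":
        block = (page_size, num_kv_heads, head_dim)
    elif kv_layout == "HND":
        block = (num_kv_heads, page_size, head_dim)
    else:
        raise ValueError(f"kv_layout must be 'NHD' or 'HND', got {kv_layout!r}")

    if as_tuple:
        # Tuple form: two tensors, each [num_pages, *block].
        return (num_pages, *block), (num_pages, *block)

    # Single-tensor form: [num_pages, 1 or 2, *block].
    if kv_dim not in (1, 2):
        raise ValueError("the second dimension must be 1 or 2")
    return (num_pages, kv_dim, *block)


# Example: 128 pages of 16 tokens, 8 KV heads, head_dim 128.
print(expected_kv_cache_shapes(128, 16, 8, 128))                   # (128, 2, 16, 8, 128)
print(expected_kv_cache_shapes(128, 16, 8, 128, kv_layout="HND"))  # (128, 2, 8, 16, 128)
```

Only the position of page_size and num_kv_heads swaps between the two layouts; the leading num_pages dimension and the trailing head_dim stay in place.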
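block_tables and seq_lens together describe where each request's keys and values live in the paged cache. The plain-Python sketch below assumes the usual paged-attention convention that row b of block_tables lists request b's physical page indices in logical order; locate_token and pages_needed are hypothetical helpers for illustration, not FlashInfer APIs.

```python
def pages_needed(seq_len, page_size):
    """Number of KV pages a sequence of seq_len tokens occupies (ceiling division)."""
    return (seq_len + page_size - 1) // page_size


def locate_token(block_tables, seq_lens, page_size, batch_idx, token_pos):
    """Map (request, logical KV position) to (physical page, slot within page).

    Hypothetical helper; assumes block_tables[b] lists the physical page
    indices of request b in logical order.
    """
    if not 0 <= token_pos < seq_lens[batch_idx]:
        raise IndexError("token_pos is outside this request's KV sequence")
    logical_page = token_pos // page_size
    physical_page = block_tables[batch_idx][logical_page]
    return physical_page, token_pos % page_size


# Two requests, page_size 16. Request 0 only needs 2 pages, so its
# third block_tables entry is never read.
block_tables = [[3, 7, 0], [5, 1, 2]]
seq_lens = [20, 33]
print(locate_token(block_tables, seq_lens, 16, 0, 17))  # (7, 1)
print(locate_token(block_tables, seq_lens, 16, 1, 32))  # (2, 0)
print([pages_needed(n, 16) for n in seq_lens])          # [2, 3]
```

For example, with page_size = 16 and block_tables[0] = [3, 7, 0], token 17 of request 0 lives in physical page 7, slot 1.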
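A minimal sketch of which KV positions a single decode query attends to under window_left. It assumes the query sits at the last position of the sequence and that "inclusive" means keys at distance 0 through window_left from it are kept, i.e. window_left + 1 keys in total; visible_kv_range is a hypothetical helper, and the kernel's exact boundary convention should be checked against the implementation.

```python
def visible_kv_range(seq_len, window_left=-1):
    """Half-open range [start, seq_len) of KV positions one decode query attends to.

    Hypothetical helper. Assumes the query sits at position seq_len - 1 and
    that keys at distance 0..window_left from it are kept (inclusive).
    """
    if window_left < 0:
        # -1 disables the sliding window: the whole sequence is visible.
        return 0, seq_len
    start = max(0, seq_len - 1 - window_left)
    return start, seq_len


print(visible_kv_range(10))                 # (0, 10)
print(visible_kv_range(10, window_left=3))  # (6, 10): positions 6..9, i.e. 4 keys
print(visible_kv_range(2, window_left=5))   # (0, 2): window longer than the sequence
```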
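One way to read the sinks parameter, sketched for a single head: the sink adds one extra term to the softmax normalizer without contributing any value vector, so the attention weights over the real keys sum to less than 1. The sketch assumes the sink is given as a learned logit, so it contributes exp(sink) to the denominator; the kernel's exact convention may differ, and softmax_with_sink is a hypothetical helper.

```python
import math


def softmax_with_sink(logits, sink):
    """Attention weights for one head with an extra sink term in the denominator.

    Assumes the sink is a logit, so it contributes exp(sink) to the normalizer.
    No value vector is attached to the sink, so the returned weights sum to < 1.
    """
    m = max(max(logits), sink)                # shift by the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    denom = sum(exps) + math.exp(sink - m)
    return [e / denom for e in exps]


w = softmax_with_sink([0.0, 0.0], sink=0.0)
print(w)  # each weight is 1/3; the sink absorbs the remaining 1/3
```

Passing sink=float("-inf") makes exp(sink) vanish and recovers an ordinary softmax, which is a convenient sanity check.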
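With q_len_per_req > 1 (speculative decoding / MTP), query packs q_len_per_req rows per request. The sketch below shows that row layout and a plain causal visibility rule among the new tokens, assuming each request's rows are contiguous, that seq_lens already counts the new tokens, and that they occupy the last q_len_per_req KV positions in order. It is not necessarily the format the mask argument expects, whose dtype and shape are not documented here; query_row, causal_visible_len and causal_mask are hypothetical helpers.

```python
def query_row(batch_idx, token_idx, q_len_per_req):
    """Row of query holding new token token_idx of request batch_idx.

    Assumes each request's q_len_per_req tokens are contiguous along dim 0.
    """
    return batch_idx * q_len_per_req + token_idx


def causal_visible_len(seq_len, q_len_per_req, token_idx):
    """How many KV positions new token token_idx may see under a causal rule.

    Assumes seq_len already counts the q_len_per_req new tokens and that they
    occupy the last q_len_per_req KV positions, in order.
    """
    return seq_len - q_len_per_req + token_idx + 1


def causal_mask(seq_len, q_len_per_req):
    """mask[i][k] is True if new token i may attend to KV position k."""
    return [
        [k < causal_visible_len(seq_len, q_len_per_req, i) for k in range(seq_len)]
        for i in range(q_len_per_req)
    ]


batch_size, q_len_per_req = 4, 3
print(batch_size * q_len_per_req)  # 12 rows in query
for row in causal_mask(seq_len=5, q_len_per_req=3):
    print([int(v) for v in row])
# [1, 1, 1, 0, 0]
# [1, 1, 1, 1, 0]
# [1, 1, 1, 1, 1]
```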