flashinfer.decode.xqa_batch_decode_with_kv_cache

flashinfer.decode.xqa_batch_decode_with_kv_cache(query: Tensor, kv_cache: Tensor | Tuple[Tensor, Tensor], workspace_buffer: Tensor, block_tables: Tensor, seq_lens: Tensor, max_seq_len: int, bmm1_scale: float | Tensor = 1.0, bmm2_scale: float | Tensor = 1.0, window_left: int = -1, out: Tensor | None = None, sinks: Tensor | None = None, kv_layout: str = 'NHD', enable_pdl: bool = None, q_len_per_req: int | None = 1, o_scale: float | None = 1.0, mask: Tensor | None = None, kv_cache_sf: Tensor | Tuple[Tensor, Tensor] | None = None) → Tensor
Parameters:
  • query (torch.Tensor) – Query tensor with shape [num_tokens, num_heads, head_dim], where num_tokens = batch_size * q_len_per_req.

  • kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – Paged KV cache. A single tensor has shape [num_pages, 1 or 2, page_size, num_kv_heads, head_dim] when kv_layout is NHD, or [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] when kv_layout is HND. A tuple holds two tensors (key cache, value cache), each with shape [num_pages, page_size, num_kv_heads, head_dim] when kv_layout is NHD, or [num_pages, num_kv_heads, page_size, head_dim] when kv_layout is HND.

  • workspace_buffer (torch.Tensor) – Workspace buffer used by the kernel. It must be zero-initialized before its first use.

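  • block_tables (torch.Tensor) – Page table of the KV cache, with shape [batch_size, num_pages]. Row i lists the indices of the pages in kv_cache that hold the KV entries of request i.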

  • seq_lens (torch.Tensor) – A uint32 1D tensor of shape [batch_size] giving the KV sequence length of each request.

  • max_seq_len (int) – Maximum sequence length of the KV cache.

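  • bmm1_scale (Union[float, torch.Tensor]) – Fused scale for the first batched matmul (Q·K^T), typically the softmax scale multiplied by any query/key dequantization scales.

  • bmm2_scale (Union[float, torch.Tensor]) – Fused scale for the second batched matmul (the softmax probabilities times V), typically including any value dequantization scale.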

  • window_left (int = -1) – Left (inclusive) size of the attention window. When set to -1, the window covers the full length of the sequence. Defaults to -1.

  • out (Optional[torch.Tensor] = None) – Output tensor. If not provided, it is allocated with dtype query.dtype.

  • sinks (Optional[torch.Tensor] = None) – Per-head attention-sink values, which contribute an additional term to the softmax denominator.

  • kv_layout (str) – The layout of the kv cache. Can be either NHD or HND. Defaults to NHD.

  • enable_pdl (bool) – Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization for details. Only supported on sm90 and newer, and currently only for FA2, CUDA core, and trtllm-gen decode.

  • q_len_per_req (Optional[int] = 1) – Number of query tokens per request (i.e. speculative-decoding / MTP depth). query is expected to have batch_size * q_len_per_req rows along its leading dimension. Defaults to 1.

  • o_scale (Optional[float] = 1.0) – Output scale factor, applied when the output is FP8.

  • mask (Optional[torch.Tensor] = None) – causal attention mask for xqa speculative decoding.

  • kv_cache_sf (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None) – KV cache scaling factors. Must be provided when an NVFP4 KV cache is used.

Returns:

out – output torch.Tensor.

Return type:

torch.Tensor
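The following pure-Python sketch shows how the shape rules for query, kv_cache, block_tables and seq_lens relate to one another, and how a KV position is found through the page table. It makes several assumptions. Query rows are taken to be request-major, so request i owns rows i*q_len_per_req to (i+1)*q_len_per_req - 1. The size-2 axis of a single kv_cache tensor is taken to hold K, then V. All helper names are hypothetical and are not part of FlashInfer.

```python
from typing import Sequence, Tuple


def query_rows(request: int, q_len_per_req: int) -> range:
    # Hypothetical helper. Assumes a request-major layout, so each request
    # owns a contiguous block of q_len_per_req rows.
    return range(request * q_len_per_req, (request + 1) * q_len_per_req)


def expected_query_shape(batch_size: int, q_len_per_req: int,
                         num_heads: int, head_dim: int) -> Tuple[int, int, int]:
    # num_tokens = batch_size * q_len_per_req
    return (batch_size * q_len_per_req, num_heads, head_dim)


def expected_kv_cache_shape(num_pages: int, page_size: int, num_kv_heads: int,
                            head_dim: int, kv_layout: str = "NHD",
                            single_tensor: bool = True) -> Tuple[int, ...]:
    # NHD keeps tokens before heads inside a page; HND swaps those two axes.
    if kv_layout == "NHD":
        per_page = (page_size, num_kv_heads, head_dim)
    elif kv_layout == "HND":
        per_page = (num_kv_heads, page_size, head_dim)
    else:
        raise ValueError(f"unsupported kv_layout: {kv_layout!r}")
    if single_tensor:
        # Single-tensor form with a size-2 axis after num_pages
        # (assumed to hold K, then V).
        return (num_pages, 2) + per_page
    # Tuple form: this is the shape of each of the two tensors.
    return (num_pages,) + per_page


def num_used_pages(seq_len: int, page_size: int) -> int:
    # Pages needed to hold seq_len tokens (ceiling division).
    return (seq_len + page_size - 1) // page_size


def locate_token(block_tables: Sequence[Sequence[int]], seq_lens: Sequence[int],
                 request: int, position: int, page_size: int) -> Tuple[int, int]:
    # Map a KV position of one request to (physical page index, slot in page).
    if not 0 <= position < seq_lens[request]:
        raise IndexError("position outside the request's KV sequence")
    page = block_tables[request][position // page_size]
    return page, position % page_size
```

As an example, with page_size=16 the KV entry at position 35 of a request whose block_tables row is [7, 3, 9] lives in slot 3 of page 9. This lookup is the same for NHD and HND, because kv_layout only changes the axis order inside each page.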
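The next sketch is a slow pure-Python reference for one query token and one head. It illustrates what bmm1_scale, bmm2_scale, window_left and sinks do, and it is not the XQA kernel. It rests on the following assumptions:

- bmm1_scale multiplies the Q·K^T logits (commonly 1/sqrt(head_dim) times any Q/K dequantization scales).
- bmm2_scale multiplies the probability-weighted sum of V.
- The query sits at the last KV position pos and attends to keys in [pos - window_left, pos].
- A sink behaves as an extra logit whose exponential is added only to the softmax denominator.

The function name is hypothetical.

```python
import math
from typing import List, Optional, Sequence


def reference_decode_head(q: Sequence[float],
                          keys: Sequence[Sequence[float]],
                          values: Sequence[Sequence[float]],
                          bmm1_scale: float = 1.0,
                          bmm2_scale: float = 1.0,
                          window_left: int = -1,
                          sink: Optional[float] = None) -> List[float]:
    # Hypothetical reference. The query is assumed to sit at the last KV position.
    pos = len(keys) - 1
    # window_left == -1 means the full sequence; otherwise keys [pos - window_left, pos].
    start = 0 if window_left < 0 else max(0, pos - window_left)
    ks, vs = keys[start:pos + 1], values[start:pos + 1]

    # bmm1: dot-product logits, scaled by the fused bmm1_scale.
    logits = [bmm1_scale * sum(a * b for a, b in zip(q, k)) for k in ks]

    # Numerically stable softmax. The sink (if any) is treated as an extra
    # logit that only contributes to the denominator.
    m = max(logits) if sink is None else max(max(logits), sink)
    weights = [math.exp(x - m) for x in logits]
    denom = sum(weights) + (math.exp(sink - m) if sink is not None else 0.0)

    # bmm2: probability-weighted sum of values, then the fused bmm2_scale.
    out = [0.0] * len(vs[0])
    for w, v in zip(weights, vs):
        for d, x in enumerate(v):
            out[d] += w * x
    return [bmm2_scale * x / denom for x in out]
```

For example, reference_decode_head([1.0], [[0.0]], [[4.0]], sink=0.0) returns [2.0], because the sink doubles the denominator. In the batched kernel, each query head would read the KV head it is grouped with under the usual grouped-query convention (num_heads / num_kv_heads query heads per KV head). The keys and values would come from the pages listed in block_tables.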