flashinfer.decode.trtllm_batch_decode_with_kv_cache
- flashinfer.decode.trtllm_batch_decode_with_kv_cache(query: Tensor, kv_cache: Tensor | Tuple[Tensor, Tensor], workspace_buffer: Tensor, block_tables: Tensor, seq_lens: Tensor, max_seq_len: int, bmm1_scale: float | Tensor = 1.0, bmm2_scale: float | Tensor = 1.0, window_left: int = -1, out: Tensor | FP4Tensor | None = None, out_dtype: str | dtype | None = None, o_sf_scale: float | None = None, o_sf_vec_size: int | None = None, sinks: List[Tensor] | None = None, kv_layout: str = 'HND', enable_pdl: bool | None = None, backend: str = 'auto', q_len_per_req: int | None = 1, o_scale: float | None = 1.0, mask: Tensor | None = None, max_q_len: int | None = None, cum_seq_lens_q: Tensor | None = None, kv_block_scales: Tensor | Tuple[Tensor, Tensor] | None = None, skip_softmax_threshold_scale_factor: float | None = None, kv_cache_sf: Tuple[Tensor, Tensor] | None = None, uses_shared_paged_kv_idx: bool = True) → Tensor | FP4Tensor
- Parameters:
query (torch.Tensor) – query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = total query tokens in the batch.
kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) –
If kv_cache is a single tensor, it should have shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] if kv_layout is HND, or [num_pages, 1 or 2, page_size, num_kv_heads, head_dim] if kv_layout is NHD. If kv_cache is a tuple of two tensors, each should have shape [num_pages, num_kv_heads, page_size, head_dim] if kv_layout is HND, or [num_pages, page_size, num_kv_heads, head_dim] if kv_layout is NHD. The first tensor is the key cache, and the second is the value cache.
Contiguity requirements (trtllm-gen backend): head_dim (the last dim) must have stride 1; this is a TMA hardware constraint. The head and batch/page dims can have arbitrary strides.
workspace_buffer (torch.Tensor) – Workspace buffer. Must be zero-initialized before its first use.
block_tables (torch.Tensor) – Page table of the KV cache. When uses_shared_paged_kv_idx is True (the default), shape: [batch_size, max_num_pages_per_seq]. When uses_shared_paged_kv_idx is False, shape: [batch_size, 2, max_num_pages_per_seq], where dim 1 distinguishes K (0) and V (1) page indices.
seq_lens (torch.Tensor) – A uint32 1D tensor holding the KV sequence length of each prompt. Shape: [batch_size].
max_seq_len (int) – Maximum sequence length of the kv_cache.
bmm1_scale (Union[float, torch.Tensor]) – Fused scale for the bmm1 input. When using the trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
bmm2_scale (Union[float, torch.Tensor]) – Fused scale for the bmm2 input. When using the trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
window_left (int = -1) – The left (inclusive) window size for the attention window. When set to -1, the window covers the full length of the sequence. Defaults to -1.
out (Optional[Union[torch.Tensor, FP4Tensor]] = None) – Output tensor. If not provided, one is allocated with out_dtype; if out_dtype is also not provided, the dtype of query is used.
out_dtype (Optional[Union[torch.dtype, str]] = None) – Output dtype. If not provided, the dtype of out is used. For nvfp4, use the string "nvfp4".
o_sf_scale (Optional[float] = None) – Scale for the nvfp4 output tensor scale factor.
o_sf_vec_size (Optional[int] = None) – Vector size for the nvfp4 output tensor scale factor.
sinks (Optional[List[torch.Tensor]] = None) – Additional value per head added to the denominator of the softmax.
kv_layout (str = "HND") – The layout of the input k/v tensors, could be either
NHDorHND. Defaults toHND. For the trtllm-gen backend with NVFP4 KV cache, usingNHDwill trigger an automatic transpose and.contiguous()copy of both the KV data and block scale tensors to convert them to HND layout. This incurs extra memory allocation and data copy overhead. UseHNDfor better performance.enable_pdl (Optional[bool] = None) – Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization When set to
None, the backend will be chosen based on the device architecture and kernel availability.backend (str = "auto") – The implementation backend, could be
auto/xqaortrtllm-gen. Defaults toauto. When set toauto, the backend will be chosen based on the device architecture and kernel availability. For sm_100 and sm_103 (blackwell architecture),autowill choosetrtllm-genbackend. For sm_90 (hopper architecture) and sm_120/sm_121 (blackwell architecture),autowill choosexqabackend.o_scale (Optional[float] = 1.0) – output scale factor for xqa fp8 output.
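The auto dispatch rule above can be sketched as a tiny helper. `pick_backend` is a hypothetical name used only for illustration; the real dispatch lives inside flashinfer:

```python
def pick_backend(compute_capability: tuple) -> str:
    """Illustrative sketch of the 'auto' backend rule described above.

    `compute_capability` is the (major, minor) pair, e.g. (10, 0) for sm_100.
    Hypothetical helper -- not part of the flashinfer API.
    """
    sm = compute_capability[0] * 10 + compute_capability[1]
    if sm in (100, 103):        # Blackwell sm_100 / sm_103
        return "trtllm-gen"
    if sm in (90, 120, 121):    # Hopper sm_90, Blackwell sm_120 / sm_121
        return "xqa"
    raise ValueError(f"no decode backend documented for sm_{sm}")

print(pick_backend((10, 0)))   # trtllm-gen
print(pick_backend((9, 0)))    # xqa
```

On other architectures the behavior of auto is not documented here, so the sketch simply raises.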
mask (Optional[torch.Tensor] = None) – Causal attention mask for xqa speculative decoding.
max_q_len (Optional[int] = None) – The maximum query sequence length across all requests when using variable-length queries. Only supported by the trtllm-gen backend. Must be provided together with cum_seq_lens_q. When None, all requests use the uniform query length specified by q_len_per_req.
cum_seq_lens_q (Optional[torch.Tensor] = None) – Cumulative query sequence lengths for variable-length query support. Shape: [batch_size + 1], dtype: torch.int32. Only supported by the trtllm-gen backend. Must be provided together with max_q_len. When None, all requests use the uniform query length specified by q_len_per_req.
kv_block_scales (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None) – Per-block scale factors for the NVFP4 KV cache. Either a tuple of (k_scales, v_scales) or a single tensor with shape [num_pages, 2, ...] that will be unbound along dim=1. Each scale tensor has shape [num_pages, num_kv_heads, page_size, head_dim // 16] in HND layout, with dtype torch.float8_e4m3fn.
Contiguity requirements (trtllm-gen backend): the last two dims (page_size, head_dim // 16) must be contiguous (i.e., stride[-1] == 1 and stride[-2] == head_dim // 16), because the kernel reshapes them into (16, page_size * head_dim / 16 / 16) to satisfy TMA's 16-byte minimum box width. The head and batch/page dims can have arbitrary strides.
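The stride requirement for the block-scale tensors can be checked ahead of time. The sketch below works on plain (shape, strides) tuples; `check_block_scale_strides` is our illustrative name, not a flashinfer API:

```python
def check_block_scale_strides(shape, strides) -> bool:
    """Check the trtllm-gen contiguity requirement for NVFP4 block scales.

    `shape`/`strides` describe a [num_pages, num_kv_heads, page_size,
    head_dim // 16] tensor. Only the last two dims must be contiguous;
    page and head dims may be strided arbitrarily. Hypothetical helper,
    not part of flashinfer.
    """
    head_dim_div16 = shape[-1]
    return strides[-1] == 1 and strides[-2] == head_dim_div16

# A fully contiguous [num_pages=4, num_kv_heads=8, page_size=16, head_dim//16=8] tensor:
shape = (4, 8, 16, 8)
print(check_block_scale_strides(shape, (1024, 128, 8, 1)))   # True
# A non-contiguous head dim is still fine:
print(check_block_scale_strides(shape, (2048, 256, 8, 1)))   # True
# Transposing the last two dims violates the requirement:
print(check_block_scale_strides(shape, (1024, 128, 1, 16)))  # False
```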
skip_softmax_threshold_scale_factor (Optional[float] = None) – Threshold scale factor for skipping softmax operations. Providing a value enables skip-softmax sparsity as described in https://arxiv.org/abs/2512.12087; if no value is provided, standard attention is used. Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation. The actual threshold equals the provided scale factor divided by the context length.
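The threshold arithmetic in the last sentence is a one-liner; the function and variable names below are illustrative, not part of flashinfer:

```python
def skip_softmax_threshold(threshold_scale_factor: float, context_len: int) -> float:
    # Per the docs, the effective threshold is the provided scale factor
    # divided by the context length, so longer contexts get a lower
    # (stricter) skipping threshold.
    return threshold_scale_factor / context_len

print(skip_softmax_threshold(8.0, 1024))   # 0.0078125
```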
uses_shared_paged_kv_idx (bool = True) – Whether the K and V page indices are shared as a unified index. True (the default) uses the vLLM/FlashInfer layout with a 2D page table; False uses the TRT-LLM layout with a 3D page table of shape [batch_size, 2, max_num_pages_per_seq].
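To make the two page-table layouts concrete, here is a sketch that computes the expected block_tables shape from the batch configuration. The helper name and arguments are illustrative, not flashinfer API:

```python
import math

def block_tables_shape(batch_size: int, max_seq_len: int, page_size: int,
                       uses_shared_paged_kv_idx: bool = True):
    # Each sequence needs ceil(max_seq_len / page_size) page slots.
    max_num_pages_per_seq = math.ceil(max_seq_len / page_size)
    if uses_shared_paged_kv_idx:
        # vLLM/FlashInfer layout: K and V share one index per page.
        return (batch_size, max_num_pages_per_seq)
    # TRT-LLM layout: dim 1 distinguishes K (0) and V (1) page indices.
    return (batch_size, 2, max_num_pages_per_seq)

print(block_tables_shape(8, 4096, 16))          # (8, 256)
print(block_tables_shape(8, 4096, 16, False))   # (8, 2, 256)
```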
- Returns:
out – output torch.Tensor or FP4Tensor.
- Return type:
Union[torch.Tensor, FP4Tensor]
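As a closing illustration, the tensor shapes the parameters above imply can be assembled for a concrete configuration. This is a pure shape sketch (no GPU call); every name below is ours, and it assumes a single combined kv_cache tensor with a shared page index:

```python
import math

def plan_shapes(batch_size, q_len_per_req, num_heads, num_kv_heads,
                head_dim, page_size, max_seq_len, kv_layout="HND"):
    """Expected tensor shapes for trtllm_batch_decode_with_kv_cache,
    assuming a single combined kv_cache tensor and uses_shared_paged_kv_idx=True.
    Illustrative helper, not part of flashinfer."""
    num_tokens = batch_size * q_len_per_req
    max_num_pages_per_seq = math.ceil(max_seq_len / page_size)
    num_pages = batch_size * max_num_pages_per_seq  # no page sharing assumed
    if kv_layout == "HND":
        kv = (num_pages, 2, num_kv_heads, page_size, head_dim)
    else:  # NHD
        kv = (num_pages, 2, page_size, num_kv_heads, head_dim)
    return {
        "query": (num_tokens, num_heads, head_dim),
        "kv_cache": kv,
        "block_tables": (batch_size, max_num_pages_per_seq),
        "seq_lens": (batch_size,),
    }

shapes = plan_shapes(batch_size=4, q_len_per_req=1, num_heads=32,
                     num_kv_heads=8, head_dim=128, page_size=16,
                     max_seq_len=2048, kv_layout="HND")
print(shapes["query"])     # (4, 32, 128)
print(shapes["kv_cache"])  # (512, 2, 8, 16, 128)
```

Tensors with these shapes (on a supported GPU, with the appropriate dtypes and a zero-initialized workspace_buffer) would then be passed directly to the function above.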