flashinfer.decode.trtllm_batch_decode_with_kv_cache

flashinfer.decode.trtllm_batch_decode_with_kv_cache(query: Tensor, kv_cache: Tensor | Tuple[Tensor, Tensor], workspace_buffer: Tensor, block_tables: Tensor, seq_lens: Tensor, max_seq_len: int, bmm1_scale: float | Tensor = 1.0, bmm2_scale: float | Tensor = 1.0, window_left: int = -1, out: Tensor | FP4Tensor | None = None, out_dtype: str | dtype | None = None, o_sf_scale: float | None = None, o_sf_vec_size: int | None = None, sinks: List[Tensor] | None = None, kv_layout: str = 'HND', enable_pdl: bool | None = None, backend: str = 'auto', q_len_per_req: int | None = 1, o_scale: float | None = 1.0, mask: Tensor | None = None, max_q_len: int | None = None, cum_seq_lens_q: Tensor | None = None, kv_block_scales: Tensor | Tuple[Tensor, Tensor] | None = None, skip_softmax_threshold_scale_factor: float | None = None, kv_cache_sf: Tuple[Tensor, Tensor] | None = None, uses_shared_paged_kv_idx: bool = True) Tensor | FP4Tensor
Parameters:
  • query (torch.Tensor) – query tensor with shape [num_tokens, num_heads, head_dim], where num_tokens is the total number of query tokens in the batch.

  • kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) –

    If kv_cache is a single tensor, its shape is [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] when kv_layout is HND, or [num_pages, 1 or 2, page_size, num_kv_heads, head_dim] when kv_layout is NHD. If kv_cache is a tuple of two tensors, each tensor has shape [num_pages, num_kv_heads, page_size, head_dim] (HND) or [num_pages, page_size, num_kv_heads, head_dim] (NHD); the first tensor is the key cache and the second is the value cache.

    Contiguity requirements (trtllm-gen backend):

    • The head_dim (last dim) must have stride 1. This is a TMA hardware constraint.

    • The head and batch/page dims can have arbitrary strides.

  • workspace_buffer (torch.Tensor) – workspace buffer; must be initialized to 0 before its first use.

  • block_tables (torch.Tensor) – Page table of kv cache. When uses_shared_paged_kv_idx is True (default): shape [batch_size, max_num_pages_per_seq]. When uses_shared_paged_kv_idx is False: shape [batch_size, 2, max_num_pages_per_seq] where dim 1 distinguishes K (0) and V (1) page indices.

  • seq_lens (torch.Tensor) – A uint32 1D tensor of shape [batch_size] giving the KV sequence length of each request.

  • max_seq_len (int) – Maximum sequence length of the kv_cache.

  • bmm1_scale (Union[float, torch.Tensor]) – Fused scale for the bmm1 input. When using the trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.

  • bmm2_scale (Union[float, torch.Tensor]) – Fused scale for the bmm2 input. When using the trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.

  • window_left (int = -1) – The left (inclusive) window size for the attention window. When set to -1, the window covers the full length of the sequence. Defaults to -1.

  • out (Optional[Union[torch.Tensor, FP4Tensor]] = None) – Output tensor. If not provided, one is allocated with dtype out_dtype; if out_dtype is also not provided, the dtype of query is used.

  • out_dtype (Optional[Union[torch.dtype, str]] = None) – Output dtype. If not provided, the dtype of out is used. For NVFP4 output, pass the string "nvfp4".

  • o_sf_scale (Optional[float] = None) – scale for nvfp4 output tensor scale factor.

  • o_sf_vec_size (Optional[int] = None) – vector size for nvfp4 output tensor scale factor.

  • sinks (Optional[List[torch.Tensor]] = None) – Additional per-head value added to the denominator of the softmax (attention sinks).

  • kv_layout (str = "HND") – The layout of the input k/v tensors, could be either NHD or HND. Defaults to HND. For the trtllm-gen backend with NVFP4 KV cache, using NHD will trigger an automatic transpose and .contiguous() copy of both the KV data and block scale tensors to convert them to HND layout. This incurs extra memory allocation and data copy overhead. Use HND for better performance.

  • enable_pdl (Optional[bool] = None) – Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization When set to None, PDL is enabled automatically based on the device architecture and kernel availability.

  • backend (str = "auto") – The implementation backend: one of "auto", "xqa", or "trtllm-gen". Defaults to "auto", in which case the backend is chosen based on the device architecture and kernel availability: for sm_100 and sm_103 (Blackwell architecture), auto chooses the trtllm-gen backend; for sm_90 (Hopper architecture) and sm_120/sm_121 (Blackwell architecture), auto chooses the xqa backend.

  • o_scale (Optional[float] = 1.0) – output scale factor for xqa fp8 output.

  • mask (Optional[torch.Tensor] = None) – causal attention mask for xqa speculative decoding.

  • max_q_len (Optional[int] = None) – The maximum query sequence length across all requests when using variable-length queries. Only supported by trtllm-gen backend. Must be provided together with cum_seq_lens_q. When None, all requests use uniform query length specified by q_len_per_req.

  • cum_seq_lens_q (Optional[torch.Tensor] = None) – Cumulative query sequence lengths for variable-length query support, shape: [batch_size + 1], dtype: torch.int32. Only supported by trtllm-gen backend. Must be provided together with max_q_len. When None, all requests use uniform query length specified by q_len_per_req.

  • kv_block_scales (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None) –

    Per-block scale factors for NVFP4 KV cache. Either a tuple of (k_scales, v_scales) or a single tensor with shape [num_pages, 2, ...] that will be unbound along dim=1. Each scale tensor has shape [num_pages, num_kv_heads, page_size, head_dim // 16] in HND layout, with dtype torch.float8_e4m3fn.

    Contiguity requirements (trtllm-gen backend):

    • The last two dims (page_size, head_dim // 16) must be contiguous (i.e., stride[-1] == 1 and stride[-2] == head_dim // 16). This is because the kernel reshapes them into (16, page_size * head_dim / 16 / 16) to satisfy TMA’s 16-byte box width minimum.

    • The head and batch/page dims can have arbitrary strides.

  • skip_softmax_threshold_scale_factor (Optional[float] = None) – Threshold scale factor for skipping softmax operations. Providing a value enables skip-softmax sparsity as described in https://arxiv.org/abs/2512.12087. If no value is provided, standard attention is used. Setting the threshold higher generally increases kernel performance at the cost of accuracy degradation. The actual threshold equals skip_softmax_threshold_scale_factor divided by the context length.

  • uses_shared_paged_kv_idx (bool = True) – Whether the K and V page indices are shared as a unified index. True (default) uses vLLM/FlashInfer layout with a 2D page table. False uses TRT-LLM layout with a 3D page table [batch_size, 2, max_num_pages_per_seq].
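The shape rules above for kv_cache and block_tables can be summarized with two small helpers. This is a pure-Python sketch: the helper names are illustrative and not part of the FlashInfer API, and max_num_pages_per_seq is derived here as ceil(max_seq_len / page_size), which assumes the caller pages at that granularity.

```python
import math

def kv_cache_shape(num_pages, num_kv_heads, page_size, head_dim,
                   kv_layout="HND", combined=True):
    """Expected kv_cache shape per the parameter description.

    combined=True:  a single tensor holding K and V (dim 1 may also
                    be 1, per the docs' "1 or 2").
    combined=False: the shape of each tensor in the (key, value) tuple.
    """
    if kv_layout == "HND":
        inner = (num_kv_heads, page_size, head_dim)
    elif kv_layout == "NHD":
        inner = (page_size, num_kv_heads, head_dim)
    else:
        raise ValueError(f"kv_layout must be 'HND' or 'NHD', got {kv_layout!r}")
    return (num_pages, 2, *inner) if combined else (num_pages, *inner)

def block_table_shape(batch_size, max_seq_len, page_size,
                      uses_shared_paged_kv_idx=True):
    """Expected block_tables shape per the parameter description."""
    max_num_pages_per_seq = math.ceil(max_seq_len / page_size)
    if uses_shared_paged_kv_idx:
        # vLLM/FlashInfer layout: K and V share one set of page indices.
        return (batch_size, max_num_pages_per_seq)
    # TRT-LLM layout: separate K (index 0) and V (index 1) page indices.
    return (batch_size, 2, max_num_pages_per_seq)
```

For example, kv_cache_shape(8, 4, 16, 128) gives (8, 2, 4, 16, 128), the combined HND layout.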

Returns:

out – output torch.Tensor or FP4Tensor.

Return type:

Union[torch.Tensor, FP4Tensor]
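For variable-length queries, cum_seq_lens_q is an exclusive prefix sum of the per-request query lengths, and max_q_len is their maximum. A minimal sketch of building both values (plain Python ints for illustration; in real usage cum_seq_lens_q must be a torch.int32 tensor on the same device as query, and the helper name is not part of the API):

```python
import itertools

def build_varlen_q_metadata(q_lens):
    """Return (cum_seq_lens_q, max_q_len) from per-request query lengths.

    cum_seq_lens_q has batch_size + 1 entries and starts at 0, matching
    the shape [batch_size + 1] stated in the parameter description.
    """
    cum_seq_lens_q = list(itertools.accumulate(q_lens, initial=0))
    return cum_seq_lens_q, max(q_lens)
```

With q_lens = [3, 1, 4] this yields cum_seq_lens_q = [0, 3, 4, 8] and max_q_len = 4; both values must then be passed together, per the parameter descriptions above.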