flashinfer.prefill.trtllm_batch_context_with_kv_cache¶
- flashinfer.prefill.trtllm_batch_context_with_kv_cache(query: Tensor, kv_cache: Tensor | Tuple[Tensor, Tensor], workspace_buffer: Tensor, block_tables: Tensor, seq_lens: Tensor, max_q_len: int, max_kv_len: int, bmm1_scale: float | Tensor, bmm2_scale: float | Tensor, batch_size: int, cum_seq_lens_q: Tensor, cum_seq_lens_kv: Tensor, window_left: int = -1, out: Tensor | FP4Tensor | None = None, out_dtype: str | dtype | None = None, o_sf_scale: float | None = None, o_sf_vec_size: int | None = None, kv_layout: str = 'HND', enable_pdl: bool | None = None, sinks: List[Tensor] | None = None, kv_block_scales: Tensor | Tuple[Tensor, Tensor] | None = None, skip_softmax_threshold_scale_factor: float | None = None, uses_shared_paged_kv_idx: bool = True) Tensor | FP4Tensor¶
- Parameters:
query (torch.Tensor) – query tensor with shape [num_tokens, num_heads, head_dim]
kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) –
If kv_cache is a single tensor, it should have shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] if kv_layout is "HND", or [num_pages, 1 or 2, page_size, num_kv_heads, head_dim] if kv_layout is "NHD". If kv_cache is a tuple of two tensors, each should have shape [num_pages, num_kv_heads, page_size, head_dim] if kv_layout is "HND", or [num_pages, page_size, num_kv_heads, head_dim] if kv_layout is "NHD". The first tensor is the key cache; the second is the value cache.
Contiguity requirements (trtllm-gen backend): the head_dim (last) dim must have stride 1 (a TMA hardware constraint); the head and batch/page dims can have arbitrary strides.
workspace_buffer (torch.Tensor) – workspace buffer; must be zero-initialized before its first use.
block_tables (torch.Tensor) – Page table of the kv cache. When uses_shared_paged_kv_idx is True (default): shape [batch_size, max_num_pages_per_seq]. When uses_shared_paged_kv_idx is False: shape [batch_size, 2, max_num_pages_per_seq], where dim 1 distinguishes K (0) and V (1) page indices.
seq_lens (torch.Tensor) – A uint32 1D tensor indicating the kv sequence length of each prompt. shape: [batch_size]
max_q_len (int) – max sequence length for query
max_kv_len (int) – max sequence length for kv_cache
bmm1_scale (Union[float, torch.Tensor]) – fused scale for the bmm1 input. When using the trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
bmm2_scale (Union[float, torch.Tensor]) – fused scale for the bmm2 input. When using the trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
batch_size (int) – batch size
cum_seq_lens_q (torch.Tensor) – cumulative sequence lengths for query. shape: [batch_size + 1]
cum_seq_lens_kv (torch.Tensor) – cumulative sequence lengths for kv_cache. shape: [batch_size + 1]
window_left (int = -1) – The left (inclusive) window size for the attention window; when set to -1, the window size is the full length of the sequence. Defaults to -1.
out (Optional[Union[torch.Tensor, FP4Tensor]] = None) – output tensor; if not provided, it will be allocated with out_dtype, and if out_dtype is also not provided, the dtype of query is used.
out_dtype (Optional[Union[torch.dtype, str]] = None) – output dtype; if not provided, the dtype of out is used. For nvfp4, use the string "nvfp4".
o_sf_scale (Optional[float] = None) – scale for the nvfp4 output tensor scale factor.
o_sf_vec_size (Optional[int] = None) – vector size for the nvfp4 output tensor scale factor.
enable_pdl (Optional[bool] = None) – Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization. Defaults to None, which means PDL is enabled if the device supports it.
kv_layout (str = "HND") – Layout of the kv cache, either "HND" or "NHD"; default is "HND". For the trtllm-gen backend with NVFP4 KV cache, using NHD triggers an automatic transpose and .contiguous() copy of both the KV data and the block scale tensors to convert them to HND layout, which incurs extra memory allocation and copy overhead. Use HND for better performance.
sinks (Optional[List[torch.Tensor]] = None) – additional value per head in the denominator of the softmax.
kv_block_scales (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None) –
Per-block scale factors for NVFP4 KV cache. Either a tuple of (k_scales, v_scales), or a single tensor with shape [num_pages, 2, ...] that will be unbound along dim=1. Each scale tensor has shape [num_pages, num_kv_heads, page_size, head_dim // 16] in HND layout, with dtype torch.float8_e4m3fn.
Contiguity requirements (trtllm-gen backend): the last two dims (page_size, head_dim // 16) must be contiguous (i.e., stride[-1] == 1 and stride[-2] == head_dim // 16), because the kernel reshapes them into (16, page_size * head_dim / 16 / 16) to satisfy TMA's 16-byte minimum box width. The head and batch/page dims can have arbitrary strides.
skip_softmax_threshold_scale_factor (Optional[float] = None) – threshold scale factor for skipping softmax operations. Providing a value enables skip-softmax sparsity as described in https://arxiv.org/abs/2512.12087; if no value is provided, standard attention is used. Higher thresholds generally increase kernel performance at the cost of accuracy degradation. The actual threshold equals the provided scale factor divided by the context length.
uses_shared_paged_kv_idx (bool = True) – Whether the K and V page indices are shared as a unified index. True (default) uses the vLLM/FlashInfer layout with a 2D page table; False uses the TRT-LLM layout with a 3D page table of shape [batch_size, 2, max_num_pages_per_seq].
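As a shape-bookkeeping sketch of how the metadata tensors above relate to per-request lengths (plain Python, no GPU; the helper names and the page size are illustrative, not part of flashinfer):

```python
# Illustrative helpers (not part of flashinfer): derive the metadata
# shapes this API expects from per-request query / kv lengths.
PAGE_SIZE = 16  # assumed kv-cache page size for this sketch


def cumulative_lengths(lengths):
    """Build a [batch_size + 1] prefix sum starting at 0, matching the
    layout of cum_seq_lens_q / cum_seq_lens_kv."""
    cum = [0]
    for n in lengths:
        cum.append(cum[-1] + n)
    return cum


def pages_needed(kv_len, page_size=PAGE_SIZE):
    """Number of kv-cache pages a sequence of kv_len tokens occupies
    (ceiling division)."""
    return (kv_len + page_size - 1) // page_size


# Example batch of 3 requests with these query / kv lengths.
q_lens = [5, 1, 7]
kv_lens = [37, 12, 40]

cum_seq_lens_q = cumulative_lengths(q_lens)    # [0, 5, 6, 13]
cum_seq_lens_kv = cumulative_lengths(kv_lens)  # [0, 37, 49, 89]

# With uses_shared_paged_kv_idx=True, block_tables is
# [batch_size, max_num_pages_per_seq]; rows are padded to the max.
max_num_pages_per_seq = max(pages_needed(n) for n in kv_lens)  # 3
```

The query tensor in this example would be packed as [13, num_heads, head_dim] (13 = total query tokens), with cum_seq_lens_q marking each request's slice.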
- Returns:
out – output torch.Tensor or FP4Tensor.
- Return type:
Union[torch.Tensor, FP4Tensor]
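For intuition on skip_softmax_threshold_scale_factor: the effective threshold the kernel applies scales inversely with context length, so a single scale factor adapts across sequences of different lengths. A minimal arithmetic sketch (plain Python; the scale-factor value is illustrative):

```python
# Illustrative only: per the parameter description above, the actual
# skip-softmax threshold equals the user-supplied scale factor divided
# by the context length, so longer contexts use a smaller threshold.
def effective_threshold(scale_factor, context_len):
    return scale_factor / context_len


# The same scale factor yields different per-sequence thresholds:
t_short = effective_threshold(8.0, 1024)  # 0.0078125
t_long = effective_threshold(8.0, 8192)   # 0.0009765625
```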