flashinfer.prefill.trtllm_batch_context_with_kv_cache
- flashinfer.prefill.trtllm_batch_context_with_kv_cache(query: torch.Tensor, kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], workspace_buffer: torch.Tensor, block_tables: torch.Tensor, seq_lens: torch.Tensor, max_q_len: int, max_kv_len: int, bmm1_scale: float, bmm2_scale: float, batch_size: int, cum_seq_lens_q: torch.Tensor, cum_seq_lens_kv: torch.Tensor, window_left: int = -1, out: torch.Tensor | FP4Tensor | None = None, out_dtype: torch.dtype | str | None = None, o_sf_scale: float | None = None, o_sf_vec_size: int | None = None, sinks: List[torch.Tensor] | None = None) → torch.Tensor | FP4Tensor
- Parameters:
query (torch.Tensor) – query tensor with shape [num_tokens, num_heads, head_dim]
kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – If kv_cache is a single tensor, its shape should be [num_pages, 1 or 2, num_kv_heads, page_size, head_dim]. If kv_cache is a tuple, it should contain two tensors (K and V), each with shape [num_pages, num_kv_heads, page_size, head_dim].
workspace_buffer (torch.Tensor) – workspace buffer used internally by the kernel
block_tables (torch.Tensor) – page table of the kv cache, with shape [batch_size, num_pages]
seq_lens (torch.Tensor) – A uint32 1D tensor indicating the kv sequence length of each prompt. shape: [batch_size]
max_q_len (int) – max sequence length for query
max_kv_len (int) – max sequence length for kv_cache
bmm1_scale (float) – fused scale for bmm1 input.
bmm2_scale (float) – fused scale for bmm2 input.
batch_size (int) – batch size
cum_seq_lens_q (torch.Tensor) – cumulative sequence length for query. shape: [batch_size + 1]
cum_seq_lens_kv (torch.Tensor) – cumulative sequence length for kv_cache. shape: [batch_size + 1]
window_left (int = -1) – The left (inclusive) window size for the attention window; when set to -1, the window size is the full length of the sequence. Defaults to -1.
out (Optional[Union[torch.Tensor, FP4Tensor]] = None) – output tensor. If not provided, it will be allocated with out_dtype; if out_dtype is not provided, the dtype of query is used.
out_dtype (Optional[Union[torch.dtype, str]] = None) – output dtype. If not provided, the dtype of out is used. For nvfp4, use the string "nvfp4".
o_sf_scale (Optional[float] = None) – scale for the nvfp4 output tensor scale factor.
o_sf_vec_size (Optional[int] = None) – vector size for nvfp4 output tensor scale factor.
sinks (Optional[List[torch.Tensor]] = None) – additional value per head in the denominator of the softmax.
- Returns:
out – output torch.Tensor or FP4Tensor.
- Return type:
Union[torch.Tensor, FP4Tensor]
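
Example (a minimal sketch): only the call signature below comes from this reference; the tensor sizes, page-table layout, dtypes, workspace size, and scale values are illustrative assumptions, not documented requirements.

```python
import torch
import flashinfer

# Illustrative sizes (assumptions for this sketch, not required by the API).
batch_size = 2
num_heads = 8
num_kv_heads = 8
head_dim = 128
page_size = 16
device, dtype = "cuda", torch.float16

# Per-request query and kv lengths (assumed values); docstring specifies uint32,
# int32 is shown here for illustration.
q_lens = torch.tensor([16, 32], dtype=torch.int32, device=device)
kv_lens = torch.tensor([48, 64], dtype=torch.int32, device=device)
num_tokens = int(q_lens.sum())
max_q_len, max_kv_len = int(q_lens.max()), int(kv_lens.max())

# Ragged query: [num_tokens, num_heads, head_dim].
query = torch.randn(num_tokens, num_heads, head_dim, dtype=dtype, device=device)

# Single-tensor paged KV cache: [num_pages, 2, num_kv_heads, page_size, head_dim].
pages_per_seq = (max_kv_len + page_size - 1) // page_size
num_pages = batch_size * pages_per_seq
kv_cache = torch.randn(num_pages, 2, num_kv_heads, page_size, head_dim,
                       dtype=dtype, device=device)

# Page table mapping each request to its pages: [batch_size, pages_per_seq].
block_tables = torch.arange(num_pages, dtype=torch.int32, device=device).reshape(
    batch_size, pages_per_seq)

# Workspace buffer; 128 MB is an assumed size, not a documented requirement.
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=device)

# Cumulative sequence lengths, shape [batch_size + 1].
zero = torch.zeros(1, dtype=torch.int32, device=device)
cum_seq_lens_q = torch.cat([zero, torch.cumsum(q_lens, dim=0, dtype=torch.int32)])
cum_seq_lens_kv = torch.cat([zero, torch.cumsum(kv_lens, dim=0, dtype=torch.int32)])

out = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
    query=query,
    kv_cache=kv_cache,
    workspace_buffer=workspace_buffer,
    block_tables=block_tables,
    seq_lens=kv_lens,
    max_q_len=max_q_len,
    max_kv_len=max_kv_len,
    bmm1_scale=1.0 / head_dim ** 0.5,  # assumed: plain softmax scaling, no extra quantization scales
    bmm2_scale=1.0,
    batch_size=batch_size,
    cum_seq_lens_q=cum_seq_lens_q,
    cum_seq_lens_kv=cum_seq_lens_kv,
)
print(out.shape)  # expected: [num_tokens, num_heads, head_dim]
```

For nvfp4 output, pass out_dtype="nvfp4" together with o_sf_scale and o_sf_vec_size; otherwise the output dtype follows out or query as described in the parameter list above.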