flashinfer.prefill.trtllm_batch_context_with_kv_cache

flashinfer.prefill.trtllm_batch_context_with_kv_cache(query: torch.Tensor, kv_cache: torch.Tensor | Tuple[torch.Tensor, torch.Tensor], workspace_buffer: torch.Tensor, block_tables: torch.Tensor, seq_lens: torch.Tensor, max_q_len: int, max_kv_len: int, bmm1_scale: float, bmm2_scale: float, batch_size: int, cum_seq_lens_q: torch.Tensor, cum_seq_lens_kv: torch.Tensor, window_left: int = -1, out: torch.Tensor | FP4Tensor | None = None, out_dtype: torch.dtype | str | None = None, o_sf_scale: float | None = None, o_sf_vec_size: int | None = None, sinks: List[torch.Tensor] | None = None) → torch.Tensor | FP4Tensor
Parameters:
  • query (torch.Tensor) – query tensor with shape [num_tokens, num_heads, head_dim]

  • kv_cache (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) – If kv_cache is a single tensor, it should have shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim]. If kv_cache is a tuple of two tensors, each tensor should have shape [num_pages, num_kv_heads, page_size, head_dim].

  • workspace_buffer (torch.Tensor) – workspace

  • block_tables (torch.Tensor) – page table of the kv cache. shape: [batch_size, num_pages]

  • seq_lens (torch.Tensor) – A uint32 1D tensor indicating the kv sequence length of each prompt. shape: [batch_size]

  • max_q_len (int) – max sequence length for query

  • max_kv_len (int) – max sequence length for kv_cache

  • bmm1_scale (float) – fused scale for bmm1 input.

  • bmm2_scale (float) – fused scale for bmm2 input.

  • batch_size (int) – batch size

  • cum_seq_lens_q (torch.Tensor) – cumulative sequence length for query. shape: [batch_size + 1]

  • cum_seq_lens_kv (torch.Tensor) – cumulative sequence length for kv_cache. shape: [batch_size + 1]

  • window_left (int = -1) – The left (inclusive) window size for the attention window. When set to -1, the window covers the full length of the sequence. Defaults to -1.

  • out (Optional[Union[torch.Tensor, FP4Tensor]] = None) – output tensor. If not provided, it is allocated with out_dtype; if out_dtype is also not provided, the dtype of query is used.

  • out_dtype (Optional[Union[torch.dtype, str]] = None) – output dtype. If not provided, the dtype of out is used. For nvfp4, pass the string "nvfp4".

  • o_sf_scale (Optional[float] = None) – scale for nvfp4 output tensor scale factor.

  • o_sf_vec_size (Optional[int] = None) – vector size for nvfp4 output tensor scale factor.

  • sinks (Optional[List[torch.Tensor]] = None) – additional value per head in the denominator of the softmax.
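
The length-related arguments (batch_size, seq_lens, max_q_len, max_kv_len, cum_seq_lens_q, cum_seq_lens_kv) can all be derived from per-request query and kv lengths. The following is a minimal sketch in plain Python, not part of flashinfer. The helper name build_batch_metadata is hypothetical, and it assumes that the cumulative lengths are counted in tokens and begin with a leading 0 (consistent with the documented shape [batch_size + 1]). It also reports how many pages each request would occupy for a given page_size, assuming pages are filled densely.

```python
from itertools import accumulate


def build_batch_metadata(q_lens, kv_lens, page_size):
    """Derive length arguments from per-request lengths (hypothetical helper)."""
    if len(q_lens) != len(kv_lens):
        raise ValueError("q_lens and kv_lens need one entry per request")
    # Assumption: cumulative lengths are in tokens and start at 0,
    # giving batch_size + 1 entries each.
    cum_seq_lens_q = [0] + list(accumulate(q_lens))
    cum_seq_lens_kv = [0] + list(accumulate(kv_lens))
    # Assumption: pages are densely filled, so a request needs
    # ceil(kv_len / page_size) pages.
    pages_per_request = [-(-n // page_size) for n in kv_lens]
    return {
        "batch_size": len(q_lens),
        "seq_lens": list(kv_lens),
        "max_q_len": max(q_lens),
        "max_kv_len": max(kv_lens),
        "cum_seq_lens_q": cum_seq_lens_q,
        "cum_seq_lens_kv": cum_seq_lens_kv,
        "pages_per_request": pages_per_request,
    }


meta = build_batch_metadata(q_lens=[3, 5, 2], kv_lens=[10, 5, 17], page_size=16)
print(meta["cum_seq_lens_q"])     # [0, 3, 8, 10]
print(meta["cum_seq_lens_kv"])    # [0, 10, 15, 32]
print(meta["pages_per_request"])  # [1, 1, 2]
```

Before calling the function, the list-valued entries would still need to be converted to tensors on the target device; the integer dtype to use is not covered by this sketch.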

Returns:

out – output torch.Tensor or FP4Tensor.

Return type:

Union[torch.Tensor, FP4Tensor]
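
To make the window_left and sinks descriptions concrete, here is a small reference sketch in plain Python of the softmax weights for a single query and a single head. It is an illustration of one plausible reading, not the kernel's implementation. The function name reference_softmax_weights is hypothetical, and the sketch assumes that (a) attention is causal, (b) a query at kv position q_pos may attend to positions j with q_pos - window_left <= j <= q_pos, and (c) the sink is a logit whose exponential is added to the softmax denominator only.

```python
import math


def reference_softmax_weights(scores, q_pos, window_left=-1, sink=None):
    """Softmax weights over kv positions for one query (hypothetical reference)."""
    logits = []
    for j, s in enumerate(scores):
        causal_ok = j <= q_pos
        # window_left < 0 means no window limit (full sequence).
        window_ok = window_left < 0 or j >= q_pos - window_left
        logits.append(s if (causal_ok and window_ok) else -math.inf)
    # Subtract the running maximum for numerical stability.
    m = max(logits + ([sink] if sink is not None else []))
    exps = [math.exp(x - m) for x in logits]
    # Assumption: the sink only enlarges the denominator; it has no value row.
    denom = sum(exps) + (math.exp(sink - m) if sink is not None else 0.0)
    return [e / denom for e in exps]


full = reference_softmax_weights([0.0] * 4, q_pos=3)
windowed = reference_softmax_weights([0.0] * 4, q_pos=3, window_left=1)
with_sink = reference_softmax_weights([0.0] * 4, q_pos=3, window_left=1, sink=0.0)
print(full)      # [0.25, 0.25, 0.25, 0.25]
print(windowed)  # [0.0, 0.0, 0.5, 0.5]
print(sum(with_sink) < 1.0)  # True
```

Under these assumptions, the weights no longer sum to 1 when a sink is present, because part of the probability mass is absorbed by the sink term in the denominator.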