flashinfer.cudnn.cudnn_batch_prefill_with_kv_cache

flashinfer.cudnn.cudnn_batch_prefill_with_kv_cache(q: Tensor, k_cache: Tensor, v_cache: Tensor, scale: float, workspace_buffer: Tensor, *, max_token_per_sequence: int, max_sequence_kv: int, actual_seq_lens_q: Tensor, actual_seq_lens_kv: Tensor, block_tables: Tensor | None = None, causal: bool, return_lse: bool, q_scale: Tensor | None = None, k_scale: Tensor | None = None, v_scale: Tensor | None = None, batch_offsets_q: Tensor | None = None, batch_offsets_o: Tensor | None = None, batch_offsets_k: Tensor | None = None, batch_offsets_v: Tensor | None = None, batch_offsets_stats: Tensor | None = None, out: Tensor | None = None, lse: Tensor | None = None, is_cuda_graph_compatible: bool = False, backend: str | None = None, o_data_type: dtype | None = None) → tuple[Tensor, Tensor | None]

Batched prefill attention with paged KV cache, backed by cuDNN SDPA.

Parameters:
  • q (torch.Tensor) – Packed query tensor with shape (total_qo_tokens, num_heads_qo, head_dim_qk).

  • k_cache (torch.Tensor) – Key cache. If paged: (total_num_pages, num_heads_kv, page_size, head_dim_qk); otherwise (total_kv_tokens, num_heads_kv, head_dim_qk).

  • v_cache (torch.Tensor) – Value cache. If paged: (total_num_pages, num_heads_kv, page_size, head_dim_vo); otherwise (total_kv_tokens, num_heads_kv, head_dim_vo).

  • scale (float) – Softmax scaling factor, typically 1 / sqrt(head_dim_qk).

  • workspace_buffer (torch.Tensor) – Workspace buffer for cuDNN. The required size scales with batch size; 128 MB is sufficient for typical prefill workloads.

  • max_token_per_sequence (int) – Maximum number of tokens per query sequence (s_qo_max).

  • max_sequence_kv (int) – Maximum number of tokens per KV sequence (s_kv_max).

  • actual_seq_lens_q (torch.Tensor) – Per-request query lengths, shape (batch_size,). On the cuDNN path, which is the default whenever cuDNN is available, this tensor must reside on the same CUDA device as q. Only the non-cuDNN fallback path accepts a CPU tensor, which it copies internally; that fallback requires a CPU tensor when is_cuda_graph_compatible is False.

  • actual_seq_lens_kv (torch.Tensor) – Per-request KV lengths, shape (batch_size,). Same device rules as actual_seq_lens_q.

  • block_tables (Optional[torch.Tensor]) – Paged KV block table, shape (batch_size, num_pages_per_seq) on GPU. Pass None for non-paged KV layouts.

  • causal (bool) – Whether to apply a causal mask.

  • return_lse (bool) – Whether to return the log-sum-exp (LSE) tensor. The cubin backend currently requires this to be True.

  • q_scale (Optional[torch.Tensor]) – FP8 dequantization scale for the query, shape (1, 1, 1, 1) on GPU.

  • k_scale (Optional[torch.Tensor]) – FP8 dequantization scale for the key, shape (1, 1, 1, 1) on GPU.

  • v_scale (Optional[torch.Tensor]) – FP8 dequantization scale for the value, shape (1, 1, 1, 1) on GPU.

  • batch_offsets_q (Optional[torch.Tensor]) – Per-request offsets into the query tensor, shape (batch_size,) on GPU.

  • batch_offsets_o (Optional[torch.Tensor]) – Per-request offsets into the output tensor, shape (batch_size,) on GPU.

  • batch_offsets_k (Optional[torch.Tensor]) – Per-request offsets into the key tensor, shape (batch_size,) on GPU.

  • batch_offsets_v (Optional[torch.Tensor]) – Per-request offsets into the value tensor, shape (batch_size,) on GPU.

  • batch_offsets_stats (Optional[torch.Tensor]) – Per-request offsets into the LSE / stats tensor, shape (batch_size,).

  • out (Optional[torch.Tensor]) – Pre-allocated output tensor, shape (total_qo_tokens, num_heads_qo, head_dim_vo). Allocated internally when None.

  • lse (Optional[torch.Tensor]) – Pre-allocated LSE tensor, shape (batch_size, max_token_per_sequence, num_heads_qo). Allocated internally when None and return_lse is True.

  • is_cuda_graph_compatible (bool) – Whether to plan the operation in a CUDA-graph-capture-safe mode.

  • backend (Optional[str]) – cuDNN backend selector (e.g. "cubin"). When None, the backend is chosen automatically based on cuDNN availability.

  • o_data_type (Optional[torch.dtype]) – Optional output dtype; defaults to q.dtype.
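
The scalar arguments listed above can be derived from the per-request lengths before the call. The following is a minimal sketch in plain Python and is not part of FlashInfer: the helper name derive_prefill_args and its return keys are hypothetical, and the page count assumes the block table is sized for ceil(max KV length / page_size) pages per request.

```python
import math


def derive_prefill_args(seq_lens_q, seq_lens_kv, head_dim_qk, page_size=None):
    """Hypothetical helper: derive scalar arguments from per-request lengths.

    seq_lens_q and seq_lens_kv are plain Python lists, one entry per request.
    """
    if len(seq_lens_q) != len(seq_lens_kv):
        raise ValueError("seq_lens_q and seq_lens_kv need one entry per request")

    args = {
        "batch_size": len(seq_lens_q),
        # Packed query layout: q is (total_qo_tokens, num_heads_qo, head_dim_qk).
        "total_qo_tokens": sum(seq_lens_q),
        "max_token_per_sequence": max(seq_lens_q),  # s_qo_max
        "max_sequence_kv": max(seq_lens_kv),        # s_kv_max
        # Typical softmax scale from the parameter description.
        "scale": 1.0 / math.sqrt(head_dim_qk),
    }
    if page_size is not None:
        # Assumption: the block table is wide enough for the longest KV sequence.
        args["num_pages_per_seq"] = math.ceil(max(seq_lens_kv) / page_size)
    return args


args = derive_prefill_args([5, 17, 3], [40, 128, 9], head_dim_qk=128, page_size=16)
```

The resulting values would map onto max_token_per_sequence, max_sequence_kv, and scale; the length lists themselves still have to be converted to tensors for actual_seq_lens_q and actual_seq_lens_kv on the device the selected path expects.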

Returns:

A tuple (output, lse). output has shape (total_qo_tokens, num_heads_qo, head_dim_vo). lse has shape (batch_size, max_token_per_sequence, num_heads_qo) when return_lse is True and is None otherwise.

Return type:

Tuple[torch.Tensor, Optional[torch.Tensor]]
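
When pre-allocating out and lse, it can help to compute the documented shapes up front. The sketch below only restates the shapes given above; the helper name expected_output_shapes is hypothetical and not part of FlashInfer.

```python
def expected_output_shapes(total_qo_tokens, batch_size, max_token_per_sequence,
                           num_heads_qo, head_dim_vo, return_lse):
    """Hypothetical helper: shapes of (output, lse) as documented under Returns."""
    # The output is packed: one row per query token across the whole batch.
    out_shape = (total_qo_tokens, num_heads_qo, head_dim_vo)
    # The LSE tensor is indexed per request and padded to the longest query.
    lse_shape = (
        (batch_size, max_token_per_sequence, num_heads_qo) if return_lse else None
    )
    return out_shape, lse_shape


out_shape, lse_shape = expected_output_shapes(
    total_qo_tokens=25, batch_size=3, max_token_per_sequence=17,
    num_heads_qo=32, head_dim_vo=128, return_lse=True,
)
```

Note that the two shapes use different layouts: the output follows the packed query layout, while the LSE shape is expressed in terms of batch_size and max_token_per_sequence.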

Note

The number of query heads may exceed the number of KV heads (num_heads_qo >= num_heads_kv), which covers MQA and GQA. When using CUDA graph capture, actual_seq_lens_q and actual_seq_lens_kv must reside on the same device as q. head_dim_qk must be 128 or 192, and head_dim_vo must be 128.
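
The head constraints in this note can be checked before launching the kernel. This is a minimal sketch of such a pre-flight check; the function name check_prefill_head_config is hypothetical, and it covers only the constraints stated in the note, not whatever additional validation the library performs internally.

```python
def check_prefill_head_config(num_heads_qo, num_heads_kv, head_dim_qk, head_dim_vo):
    """Hypothetical pre-flight check mirroring the constraints in the note.

    Returns a list of human-readable problems; an empty list means the
    configuration satisfies the documented constraints.
    """
    problems = []
    if num_heads_qo < num_heads_kv:
        problems.append(
            f"num_heads_qo ({num_heads_qo}) must be >= num_heads_kv ({num_heads_kv})"
        )
    if head_dim_qk not in (128, 192):
        problems.append(f"head_dim_qk must be 128 or 192, got {head_dim_qk}")
    if head_dim_vo != 128:
        problems.append(f"head_dim_vo must be 128, got {head_dim_vo}")
    return problems


ok = check_prefill_head_config(num_heads_qo=32, num_heads_kv=8,
                               head_dim_qk=128, head_dim_vo=128)
bad = check_prefill_head_config(num_heads_qo=8, num_heads_kv=32,
                                head_dim_qk=64, head_dim_vo=64)
```

Returning a list rather than raising on the first failure lets a caller report every violated constraint at once.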