flashinfer.cudnn.cudnn_batch_decode_with_kv_cache
- flashinfer.cudnn.cudnn_batch_decode_with_kv_cache(q: Tensor, k_cache: Tensor, v_cache: Tensor, scale: float, workspace_buffer: Tensor, *, max_sequence_kv: int, actual_seq_lens_kv: Tensor | None = None, block_tables: Tensor | None = None, is_cuda_graph_compatible: bool = False, batch_offsets_q: Tensor | None = None, batch_offsets_o: Tensor | None = None, batch_offsets_k: Tensor | None = None, batch_offsets_v: Tensor | None = None, out: Tensor | None = None) → Tensor
Batched decode attention with paged KV cache, backed by cuDNN SDPA.
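To make the semantics concrete, here is a minimal CPU reference in pure Python for what batched decode over a paged KV cache computes. It is a sketch, not flashinfer code: `paged_decode_reference` is a hypothetical name, nested lists stand in for tensors, and it assumes that token t of request b lives in page `block_tables[b][t // page_size]` at slot `t % page_size`, that only the first `actual_seq_lens_kv[b]` tokens are attended, and that consecutive query heads share a KV head. With a single query token per request sitting after all cached tokens, causal masking excludes none of the cached keys, so no mask appears.

```python
import math
from typing import List

Vec = List[float]


def paged_decode_reference(
    q: List[List[Vec]],              # [batch][num_heads_qo][head_dim]
    k_cache: List[List[List[Vec]]],  # [page][num_heads_kv][page_size][head_dim]
    v_cache: List[List[List[Vec]]],  # same layout as k_cache
    scale: float,
    seq_lens: List[int],             # per-request KV lengths (actual_seq_lens_kv)
    block_tables: List[List[int]],   # [batch][logical page] -> physical page index
) -> List[List[Vec]]:
    """Hypothetical CPU reference for single-token decode over a paged KV cache.

    Assumes the page/slot addressing and GQA head grouping described above,
    and at least one KV token per request.
    """
    num_heads_qo = len(q[0])
    num_heads_kv = len(k_cache[0])
    page_size = len(k_cache[0][0])
    group = num_heads_qo // num_heads_kv  # query heads per KV head (assumed)
    out = []
    for b, q_b in enumerate(q):
        out_b = []
        for h, q_bh in enumerate(q_b):
            h_kv = h // group
            # Gather this request's K/V rows for head h_kv, token by token.
            keys, values = [], []
            for t in range(seq_lens[b]):
                page = block_tables[b][t // page_size]
                slot = t % page_size
                keys.append(k_cache[page][h_kv][slot])
                values.append(v_cache[page][h_kv][slot])
            if not keys:
                raise ValueError("each request needs at least one KV token")
            # Scaled dot-product scores followed by a numerically stable softmax.
            scores = [scale * sum(qi * ki for qi, ki in zip(q_bh, k)) for k in keys]
            m = max(scores)
            weights = [math.exp(s - m) for s in scores]
            total = sum(weights)
            # Output row = softmax-weighted average of the gathered value rows.
            out_b.append([
                sum(w * v[d] for w, v in zip(weights, values)) / total
                for d in range(len(q_bh))
            ])
        out.append(out_b)
    return out
```

Comparing the cuDNN output against a reference like this on tiny inputs is one way to sanity-check the cache layout and block tables you pass in; it is far too slow for anything beyond testing.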
- Parameters:
q (torch.Tensor) – Query tensor of shape (batch_size, num_heads_qo, head_dim).
k_cache (torch.Tensor) – Key cache, shape (total_num_pages, num_heads_kv, page_size, head_dim).
v_cache (torch.Tensor) – Value cache, shape (total_num_pages, num_heads_kv, page_size, head_dim).
scale (float) – Softmax scaling factor, typically 1 / sqrt(head_dim).
workspace_buffer (torch.Tensor) – Workspace buffer for cuDNN. Its required size scales with batch size; 128 MB is sufficient for typical decode workloads.
max_sequence_kv (int) – Maximum number of tokens per KV sequence in the batch (s_kv_max).
actual_seq_lens_kv (Optional[torch.Tensor]) – Per-request KV lengths, shape (batch_size,). When cuDNN is available (the default backend), this tensor must reside on the same CUDA device as q. Only the fallback non-cuDNN path accepts (and internally copies) a CPU tensor.
block_tables (Optional[torch.Tensor]) – Page-table mapping for the paged KV cache, shape (batch_size, num_pages_per_seq), on GPU.
is_cuda_graph_compatible (bool) – Whether to plan the operation in a mode that is safe for CUDA graph capture.
batch_offsets_q (Optional[torch.Tensor]) – Per-request offsets into the query tensor, shape (batch_size,), on GPU.
batch_offsets_o (Optional[torch.Tensor]) – Per-request offsets into the output tensor, shape (batch_size,), on GPU.
batch_offsets_k (Optional[torch.Tensor]) – Per-request offsets into the key tensor, shape (batch_size,), on GPU.
batch_offsets_v (Optional[torch.Tensor]) – Per-request offsets into the value tensor, shape (batch_size,), on GPU.
out (Optional[torch.Tensor]) – Pre-allocated output tensor, shape (batch_size, num_heads_qo, head_dim); allocated internally when None.
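The paging parameters are related: `max_sequence_kv` is the largest entry of `actual_seq_lens_kv`, each request needs ceil(len / page_size) pages, and all `block_tables` rows share one width (`num_pages_per_seq`). The sketch below derives them from per-request lengths. It is illustrative only: `build_paged_metadata` and `DecodeMetadata` are hypothetical names, pages are handed out contiguously in request order, and padding unused block-table slots with 0 is an assumption (entries past a request's length are presumed never read).

```python
from dataclasses import dataclass
import math
from typing import List


@dataclass
class DecodeMetadata:
    block_tables: List[List[int]]  # (batch_size, num_pages_per_seq)
    actual_seq_lens_kv: List[int]  # (batch_size,)
    max_sequence_kv: int           # s_kv_max
    total_num_pages: int           # leading dim of k_cache / v_cache
    scale: float                   # 1 / sqrt(head_dim)


def build_paged_metadata(seq_lens: List[int], page_size: int, head_dim: int) -> DecodeMetadata:
    """Hypothetical helper (not a flashinfer API) deriving decode metadata."""
    # Integer ceil division: pages needed to hold each request's KV tokens.
    pages_per_seq = [(n + page_size - 1) // page_size for n in seq_lens]
    width = max(pages_per_seq)  # num_pages_per_seq: every row gets this width
    block_tables: List[List[int]] = []
    next_page = 0
    for n_pages in pages_per_seq:
        # Assumption: physical pages are assigned contiguously in request order.
        row = list(range(next_page, next_page + n_pages))
        next_page += n_pages
        # Assumption: padding slots hold 0 and are ignored past the sequence length.
        row += [0] * (width - n_pages)
        block_tables.append(row)
    return DecodeMetadata(
        block_tables=block_tables,
        actual_seq_lens_kv=list(seq_lens),
        max_sequence_kv=max(seq_lens),
        total_num_pages=next_page,
        scale=1.0 / math.sqrt(head_dim),
    )
```

For example, lengths `[5, 2, 9]` with `page_size=4` need 2, 1 and 3 pages, so each row has width 3 and the cache needs 6 pages. In practice the lists would then become CUDA tensors on q's device, e.g. `torch.tensor(meta.block_tables, device=q.device)`; check your flashinfer version for the integer dtype it expects.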
- Returns:
Output tensor of shape (batch_size, num_heads_qo, head_dim).
- Return type:
torch.Tensor
Note
Currently, only causal attention is supported. All tensors must be contiguous and reside on the same CUDA device. The number of query heads may exceed the number of KV heads (num_heads_qo >= num_heads_kv), which covers multi-query and grouped-query attention.
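The head relation in the note can be sketched as a mapping from query heads to KV heads. This goes beyond what is stated above: it assumes `num_heads_qo` is a multiple of `num_heads_kv` and that consecutive query heads share a KV head, which is the usual grouped-query convention. `kv_head_for_query_heads` is a hypothetical helper, not part of flashinfer.

```python
from typing import List


def kv_head_for_query_heads(num_heads_qo: int, num_heads_kv: int) -> List[int]:
    """Hypothetical helper: the KV head index each query head attends with."""
    if num_heads_kv <= 0 or num_heads_qo <= 0 or num_heads_qo % num_heads_kv != 0:
        # Also rejects num_heads_qo < num_heads_kv, since the remainder is then nonzero.
        raise ValueError("num_heads_qo must be a positive multiple of num_heads_kv")
    group_size = num_heads_qo // num_heads_kv  # query heads per KV head
    # Assumed convention: consecutive query heads share one KV head.
    return [h // group_size for h in range(num_heads_qo)]
```

With 8 query heads and 2 KV heads this yields `[0, 0, 0, 0, 1, 1, 1, 1]`; `num_heads_kv == 1` corresponds to multi-query attention, and equal head counts reduce to standard multi-head attention.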