flashinfer.decode.cudnn_batch_decode_with_kv_cache¶
- flashinfer.decode.cudnn_batch_decode_with_kv_cache(q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, scale: float, workspace_buffer: torch.Tensor, *, max_sequence_kv: int, actual_seq_lens_kv: torch.Tensor | None = None, block_tables: torch.Tensor | None = None, is_cuda_graph_compatible: bool = False, batch_offsets_q: torch.Tensor | None = None, batch_offsets_o: torch.Tensor | None = None, batch_offsets_k: torch.Tensor | None = None, batch_offsets_v: torch.Tensor | None = None, out: torch.Tensor | None = None) → torch.Tensor¶
Performs batched decode attention with paged KV cache using cuDNN.
- Parameters:
q – Query tensor of shape (batch_size, num_heads_qo, head_dim); each request contributes a single decode query token
k_cache – Key cache tensor of shape (total_num_pages, num_heads_kv, page_size, head_dim)
v_cache – Value cache tensor of shape (total_num_pages, num_heads_kv, page_size, head_dim)
scale – Scaling factor for attention scores, typically 1/sqrt(head_dim)
workspace_buffer – Workspace buffer for cuDNN operations. Its size scales with batch size; 128 MB should be sufficient for most cases
max_sequence_kv – Maximum number of tokens per key/value sequence (s_kv_max)
actual_seq_lens_kv – Actual sequence lengths for key/values per batch, shape (batch_size,) on CPU
block_tables – Page table mapping for KV cache, shape (batch_size, num_pages_per_seq) on GPU
is_cuda_graph_compatible – Whether the decode operation is compatible with CUDA graph
batch_offsets_q – Optional batch offsets for query tensor of shape (batch_size,) on GPU
batch_offsets_o – Optional batch offsets for output tensor of shape (batch_size,) on GPU
batch_offsets_k – Optional batch offsets for key tensor of shape (batch_size,) on GPU
batch_offsets_v – Optional batch offsets for value tensor of shape (batch_size,) on GPU
out – Optional pre-allocated output tensor
- Returns:
Output tensor of shape (batch_size, num_heads_qo, head_dim)
Note
- Currently only supports causal attention (causal must be True).
- All tensors must be contiguous and on the same CUDA device.
- Query and KV head counts may differ (num_heads_qo >= num_heads_kv).
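A minimal usage sketch based on the signature and parameter descriptions above. The batch size, head counts, page layout, fp16 data, and int32 index dtypes are illustrative assumptions, not requirements stated by this page; only the argument shapes and placements (block_tables on GPU, actual_seq_lens_kv on CPU) follow the descriptions above.

```python
import torch
from flashinfer.decode import cudnn_batch_decode_with_kv_cache

# Illustrative problem sizes (assumed, not prescribed by the API)
batch_size = 4
num_heads_qo, num_heads_kv = 32, 8      # grouped-query attention: num_heads_qo >= num_heads_kv
head_dim = 128
page_size = 16
max_seq_kv = 1024
num_pages_per_seq = max_seq_kv // page_size
total_num_pages = batch_size * num_pages_per_seq
device = "cuda"

# One decode token per request: (batch_size, num_heads_qo, head_dim)
q = torch.randn(batch_size, num_heads_qo, head_dim, dtype=torch.half, device=device)

# Paged KV cache: (total_num_pages, num_heads_kv, page_size, head_dim)
k_cache = torch.randn(total_num_pages, num_heads_kv, page_size, head_dim,
                      dtype=torch.half, device=device)
v_cache = torch.randn_like(k_cache)

# Page table mapping each sequence to its pages: (batch_size, num_pages_per_seq) on GPU
block_tables = torch.arange(total_num_pages, dtype=torch.int32, device=device).reshape(
    batch_size, num_pages_per_seq
)

# Actual KV length per request: (batch_size,) on CPU (dtype assumed int32)
actual_seq_lens_kv = torch.full((batch_size,), 512, dtype=torch.int32)

# 128 MB workspace, per the parameter description
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)

out = cudnn_batch_decode_with_kv_cache(
    q,
    k_cache,
    v_cache,
    scale=1.0 / (head_dim ** 0.5),
    workspace_buffer=workspace_buffer,
    max_sequence_kv=max_seq_kv,
    actual_seq_lens_kv=actual_seq_lens_kv,
    block_tables=block_tables,
)
print(out.shape)  # (batch_size, num_heads_qo, head_dim)
```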