flashinfer.decode.single_decode_with_kv_cache
- flashinfer.decode.single_decode_with_kv_cache(q: Tensor, k: Tensor, v: Tensor, kv_layout: str = 'NHD', pos_encoding_mode: str = 'NONE', use_tensor_cores: bool = False, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, window_left: int = -1, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, return_lse: Literal[False] = False) → Tensor
- flashinfer.decode.single_decode_with_kv_cache(q: Tensor, k: Tensor, v: Tensor, kv_layout: str = 'NHD', pos_encoding_mode: str = 'NONE', use_tensor_cores: bool = False, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, window_left: int = -1, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, return_lse: Literal[True] = True) → Tuple[Tensor, Tensor]
 Decode attention with KV cache for a single request; returns the attention output.
- Parameters:
  - q (torch.Tensor) – The query tensor, shape: [num_qo_heads, head_dim].
  - k (torch.Tensor) – The key tensor, shape: [kv_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, kv_len, head_dim] if kv_layout is HND.
  - v (torch.Tensor) – The value tensor, shape: [kv_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, kv_len, head_dim] if kv_layout is HND.
  - kv_layout (str) – The layout of the input k/v tensors, either NHD or HND.
  - pos_encoding_mode (str) – The position encoding applied inside attention kernels: NONE / ROPE_LLAMA (LLaMA-style rotary embedding) / ALIBI. Defaults to NONE.
  - use_tensor_cores (bool) – Whether to use tensor cores for the computation. Faster for large group sizes in grouped query attention. Defaults to False.
  - q_scale (Optional[float]) – The calibration scale of query for fp8 input; if not provided, will be set to 1.0.
  - k_scale (Optional[float]) – The calibration scale of key for fp8 input; if not provided, will be set to 1.0.
  - v_scale (Optional[float]) – The calibration scale of value for fp8 input; if not provided, will be set to 1.0.
  - window_left (int) – The left (inclusive) window size for the attention window; when set to -1, the window size will be set to the full length of the sequence. Defaults to -1.
  - logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok, Gemma-2, etc.); if not provided, will be set to 0. If greater than 0, the logits will be capped according to the formula \(\texttt{logits\_soft\_cap} \times \mathrm{tanh}(x / \texttt{logits\_soft\_cap})\), where \(x\) is the input logits (see the reference sketch after this list).
  - sm_scale (Optional[float]) – The scale used in softmax; if not provided, will be set to 1 / sqrt(head_dim).
  - rope_scale (Optional[float]) – The scale used in RoPE interpolation; if not provided, will be set to 1.0.
  - rope_theta (Optional[float]) – The theta used in RoPE; if not provided, will be set to 1e4.
  - return_lse (bool) – Whether to return the log sum exp value of the attention logits.
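The soft-capping formula above can be reproduced outside the kernel as a minimal reference sketch (the helper name soft_cap_reference is illustrative only; inside flashinfer the capping is fused into the attention kernel and the full logits matrix is never materialized):

import torch

def soft_cap_reference(logits: torch.Tensor, logits_soft_cap: float) -> torch.Tensor:
    # Illustrative reference only, not the kernel implementation.
    # A non-positive cap means soft capping is disabled.
    if logits_soft_cap <= 0.0:
        return logits
    # logits_soft_cap * tanh(x / logits_soft_cap), applied elementwise.
    return logits_soft_cap * torch.tanh(logits / logits_soft_cap)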
- Returns:
  If return_lse is False, the attention output, shape: [num_qo_heads, head_dim_vo]. If return_lse is True, a tuple of two tensors:
  - The attention output, shape: [num_qo_heads, head_dim_vo].
  - The log sum exp value, shape: [num_qo_heads].
- Return type:
 Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
Examples
>>> import torch
>>> import flashinfer
>>> kv_len = 4096
>>> num_qo_heads = 32
>>> num_kv_heads = 32
>>> head_dim = 128
>>> q = torch.randn(num_qo_heads, head_dim).half().to("cuda:0")
>>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> o = flashinfer.single_decode_with_kv_cache(q, k, v)
>>> o.shape
torch.Size([32, 128])
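The same call can also return the log-sum-exp of the attention logits by setting return_lse=True; a sketch reusing the tensors from the example above (the printed shapes follow the Returns description, not a captured run):

>>> o, lse = flashinfer.single_decode_with_kv_cache(q, k, v, return_lse=True)
>>> o.shape
torch.Size([32, 128])
>>> lse.shape
torch.Size([32])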
Note
The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.
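As a sketch of the grouped query attention case described in this note, reusing q, kv_len, and head_dim from the example above (the head counts are chosen for illustration; use_tensor_cores is enabled since the group size is larger than 1):

>>> num_kv_heads = 8   # num_qo_heads (32) is a multiple of num_kv_heads, group size 4
>>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> o = flashinfer.single_decode_with_kv_cache(q, k, v, use_tensor_cores=True)
>>> o.shape
torch.Size([32, 128])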