flashinfer.decode.single_decode_with_kv_cache
- flashinfer.decode.single_decode_with_kv_cache(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, kv_layout: str = 'NHD', pos_encoding_mode: str = 'NONE', use_tensor_cores: bool = False, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, window_left: int = -1, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None) → torch.Tensor
Decode attention with a KV cache for a single request; returns the attention output.
- Parameters:
q (torch.Tensor) – The query tensor, shape: [num_qo_heads, head_dim].
k (torch.Tensor) – The key tensor, shape: [kv_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, kv_len, head_dim] if kv_layout is HND.
v (torch.Tensor) – The value tensor, shape: [kv_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, kv_len, head_dim] if kv_layout is HND.
kv_layout (str) – The layout of the input k/v tensors, either NHD or HND. Defaults to NHD.
pos_encoding_mode (str) – The position encoding applied inside the attention kernels, one of NONE, ROPE_LLAMA (LLaMA-style rotary embedding), or ALIBI. Defaults to NONE.
use_tensor_cores (bool) – Whether to use tensor cores for the computation. This is faster when the group size (num_qo_heads / num_kv_heads) in grouped-query attention is large. Defaults to False.
q_scale (Optional[float]) – The calibration scale of the query for fp8 input. If not provided, it is set to 1.0.
k_scale (Optional[float]) – The calibration scale of the key for fp8 input. If not provided, it is set to 1.0.
v_scale (Optional[float]) – The calibration scale of the value for fp8 input. If not provided, it is set to 1.0.
window_left (int) – The left (inclusive) window size of the attention window. When set to -1, the window covers the full length of the sequence. Defaults to -1.
logits_soft_cap (Optional[float]) – The attention logits soft-capping value (used in Gemini, Grok, Gemma-2, etc.). If not provided, it is set to 0. If greater than 0, the logits are capped according to the formula \(\texttt{logits\_soft\_cap} \times \mathrm{tanh}(x / \texttt{logits\_soft\_cap})\), where \(x\) is the input logit.
sm_scale (Optional[float]) – The softmax scale. If not provided, it is set to 1 / sqrt(head_dim).
rope_scale (Optional[float]) – The scale used in RoPE interpolation. If not provided, it is set to 1.0.
rope_theta (Optional[float]) – The theta used in RoPE. If not provided, it is set to 1e4.
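As a rough reference for how sm_scale, logits_soft_cap and window_left interact, the following framework-free sketch computes decode attention for one query head against one KV head. It is not FlashInfer's kernel: it assumes the logits are multiplied by sm_scale before soft capping, and that the decode query sits at position kv_len - 1, so window_left = w keeps the last w + 1 keys. The function name reference_decode_head is hypothetical.

```python
import math


def reference_decode_head(q, keys, values, sm_scale=None,
                          logits_soft_cap=None, window_left=-1):
    """Hypothetical reference: decode attention for a single query head.

    q:      list of head_dim floats (the one query vector)
    keys:   kv_len lists of head_dim floats (one KV head)
    values: kv_len lists of head_dim floats (one KV head)
    """
    head_dim = len(q)
    kv_len = len(keys)
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)  # documented default

    # Assumption: the query is at position kv_len - 1, so a left window of w
    # keeps keys at positions kv_len - 1 - w ... kv_len - 1 (inclusive).
    start = 0 if window_left < 0 else max(0, kv_len - 1 - window_left)

    logits = []
    for j in range(start, kv_len):
        x = sm_scale * sum(qi * ki for qi, ki in zip(q, keys[j]))
        if logits_soft_cap is not None and logits_soft_cap > 0:
            # cap * tanh(x / cap), matching the logits_soft_cap formula
            x = logits_soft_cap * math.tanh(x / logits_soft_cap)
        logits.append(x)

    # Numerically stable softmax over the kept keys.
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    total = sum(weights)
    weights = [w / total for w in weights]

    # Softmax-weighted sum of the value vectors -> output of length head_dim.
    out = [0.0] * head_dim
    for w, j in zip(weights, range(start, kv_len)):
        for d in range(head_dim):
            out[d] += w * values[j][d]
    return out


q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
# window_left=0 keeps only the last key, so the output is its value vector.
print(reference_decode_head(q, keys, values, window_left=0))  # [5.0, 6.0]
```

Soft capping bounds every logit to (-cap, cap), which flattens the softmax when raw logits are large; the window simply drops keys before the softmax.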
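When pos_encoding_mode is ROPE_LLAMA, rope_theta and rope_scale parameterize the rotary embedding. The sketch below shows the standard LLaMA-style ("rotate half") rotation of one head vector, with linear position interpolation (the position is divided by rope_scale before computing angles). That this matches the kernel's exact convention is an assumption, and apply_rope_llama is a hypothetical helper, not a FlashInfer API.

```python
import math


def apply_rope_llama(x, pos, rope_theta=1e4, rope_scale=1.0):
    """Hypothetical helper: rotate one head vector x (even length) at position pos."""
    head_dim = len(x)
    half = head_dim // 2
    out = list(x)
    for i in range(half):
        # Frequency of the i-th dimension pair; position linearly interpolated.
        inv_freq = rope_theta ** (-2.0 * i / head_dim)
        angle = (pos / rope_scale) * inv_freq
        c, s = math.cos(angle), math.sin(angle)
        # LLaMA style pairs dimension i with dimension i + head_dim // 2.
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = b * c + a * s
    return out


# Position 0 applies no rotation.
print(apply_rope_llama([1.0, 2.0, 3.0, 4.0], 0))  # [1.0, 2.0, 3.0, 4.0]
```

Because each dimension pair is rotated by an angle proportional to the position, the dot product of a rotated query and key depends only on their relative position, which is the property RoPE is designed for.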
- Returns:
The attention output, shape: [num_qo_heads, head_dim].
- Return type:
torch.Tensor
Examples
>>> import torch
>>> import flashinfer
>>> kv_len = 4096
>>> num_qo_heads = 32
>>> num_kv_heads = 32
>>> head_dim = 128
>>> q = torch.randn(num_qo_heads, head_dim).half().to("cuda:0")
>>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> o = flashinfer.single_decode_with_kv_cache(q, k, v)
>>> o.shape
torch.Size([32, 128])
Note
num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function uses grouped-query attention.
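Under grouped-query attention, each KV head is shared by a group of num_qo_heads // num_kv_heads query heads. The sketch below assumes the common convention that query head h reads KV head h // group_size (the same mapping produced by repeating each KV head group_size times); kv_head_for_qo_head is a hypothetical helper, not part of FlashInfer.

```python
def kv_head_for_qo_head(qo_head, num_qo_heads, num_kv_heads):
    """Hypothetical helper: index of the KV head a query head attends with under GQA."""
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    group_size = num_qo_heads // num_kv_heads
    return qo_head // group_size


# 32 query heads sharing 8 KV heads: consecutive groups of 4 share one KV head.
print([kv_head_for_qo_head(h, 32, 8) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

When num_qo_heads equals num_kv_heads the group size is 1 and the mapping is the identity, which is ordinary multi-head attention.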