flashinfer.decode.single_decode_with_kv_cache

flashinfer.decode.single_decode_with_kv_cache(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, kv_layout: str = 'NHD', pos_encoding_mode: str = 'NONE', use_tensor_cores: bool = False, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, window_left: int = -1, logits_soft_cap: float | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None) → torch.Tensor

Decode attention with KV cache for a single request; returns the attention output.

Parameters:
  • q (torch.Tensor) – The query tensor, shape: [num_qo_heads, head_dim].

  • k (torch.Tensor) – The key tensor, shape: [kv_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, kv_len, head_dim] if kv_layout is HND.

  • v (torch.Tensor) – The value tensor, shape: [kv_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, kv_len, head_dim] if kv_layout is HND.

  • kv_layout (str) – The layout of the input k/v tensors; can be either NHD or HND (an HND example is shown below).

  • pos_encoding_mode (str) – The position encoding applied inside attention kernels; can be NONE, ROPE_LLAMA (LLaMA-style rotary embedding), or ALIBI. Defaults to NONE.

  • use_tensor_cores (bool) – Whether to use tensor cores for the computation. This is typically faster for large group sizes in grouped-query attention (see the example after the Note below). Defaults to False.

  • q_scale (Optional[float]) – The calibration scale of the query for fp8 input; if not provided, it will be set to 1.0.

  • k_scale (Optional[float]) – The calibration scale of the key for fp8 input; if not provided, it will be set to 1.0.

  • v_scale (Optional[float]) – The calibration scale of the value for fp8 input; if not provided, it will be set to 1.0.

  • window_left (int) – The left (inclusive) window size for the attention window; when set to -1, the window size is the full length of the sequence. Defaults to -1.

  • logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.); if not provided, it will be set to 0 (no capping). If greater than 0, the logits will be capped according to the formula: \(\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})\), where \(x\) is the input logits.

  • sm_scale (Optional[float]) – The scale used in softmax; if not provided, it will be set to 1 / sqrt(head_dim).

  • rope_scale (Optional[float]) – The scale used in RoPE interpolation; if not provided, it will be set to 1.0.

  • rope_theta (Optional[float]) – The theta used in RoPE; if not provided, it will be set to 1e4.

Returns:

The attention output, shape: [num_qo_heads, head_dim]

Return type:

torch.Tensor

Examples

>>> import torch
>>> import flashinfer
>>> kv_len = 4096
>>> num_qo_heads = 32
>>> num_kv_heads = 32
>>> head_dim = 128
>>> q = torch.randn(num_qo_heads, head_dim).half().to("cuda:0")
>>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> o = flashinfer.single_decode_with_kv_cache(q, k, v)
>>> o.shape
torch.Size([32, 128])
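
The optional arguments are passed as keywords. The sketch below reuses the setup above but stores k/v in the HND layout and enables a sliding window and logits soft capping; the window size and cap value are illustrative, not recommended settings.

>>> # Same setup, with head-major (HND) k/v, a sliding window, and soft capping.
>>> # The chosen window_left and logits_soft_cap values are only illustrative.
>>> k_hnd = torch.randn(num_kv_heads, kv_len, head_dim).half().to("cuda:0")
>>> v_hnd = torch.randn(num_kv_heads, kv_len, head_dim).half().to("cuda:0")
>>> o = flashinfer.single_decode_with_kv_cache(
...     q, k_hnd, v_hnd, kv_layout="HND", window_left=1024, logits_soft_cap=30.0
... )
>>> o.shape
torch.Size([32, 128])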

Note

num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped-query attention.
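
For grouped-query attention, the following sketch assumes the same q, kv_len, and head_dim as the example above, with 8 KV heads (a group size of 4, chosen only for illustration) and tensor cores enabled:

>>> # Grouped-query attention: 32 query heads share 8 KV heads (group size 4).
>>> num_kv_heads = 8
>>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> o = flashinfer.single_decode_with_kv_cache(q, k, v, use_tensor_cores=True)
>>> o.shape
torch.Size([32, 128])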