flashinfer.prefill.single_prefill_with_kv_cache_return_lse
- flashinfer.prefill.single_prefill_with_kv_cache_return_lse(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, custom_mask: torch.Tensor | None = None, packed_custom_mask: torch.Tensor | None = None, causal: bool = False, kv_layout: str = 'NHD', pos_encoding_mode: str = 'NONE', allow_fp16_qk_reduction: bool = False, sm_scale: float | None = None, window_left: int = -1, logits_soft_cap: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, *, return_lse: bool = True) → torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
Prefill/Append attention with KV cache for a single request. Returns the attention output and, because return_lse defaults to True here, the log sum exp value of the attention logits.
- Parameters:
q (torch.Tensor) – The query tensor, shape: [qo_len, num_qo_heads, head_dim].
k (torch.Tensor) – The key tensor, shape: [kv_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, kv_len, head_dim] if kv_layout is HND.
v (torch.Tensor) – The value tensor, shape: [kv_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, kv_len, head_dim] if kv_layout is HND.
custom_mask (Optional[torch.Tensor]) – The custom boolean mask tensor, shape: [qo_len, kv_len]. Each element must be either True or False, where False means the corresponding element of the attention matrix is masked out. When custom_mask is provided and packed_custom_mask is not, the function packs the custom mask tensor into a 1D packed mask tensor, which introduces additional overhead.
packed_custom_mask (Optional[torch.Tensor]) – The 1D packed uint8 mask tensor. If provided, custom_mask is ignored. The packed mask tensor is generated by flashinfer.quantization.packbits().
causal (bool) – Whether to apply a causal mask to the attention matrix. This only takes effect when custom_mask is not provided.
kv_layout (str) – The layout of the input k/v tensors, either NHD or HND.
pos_encoding_mode (str) – The position encoding applied inside attention kernels: NONE, ROPE_LLAMA (LLAMA style rotary embedding), or ALIBI. Default is NONE.
allow_fp16_qk_reduction (bool) – Whether to use f16 for qk reduction (faster at the cost of slight precision loss).
window_left (int) – The left (inclusive) window size for the attention window. When set to -1, the window size is set to the full length of the sequence. Defaults to -1.
logits_soft_cap (Optional[float]) – The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.). If not provided, it is set to 0. If greater than 0, the logits are capped according to the formula \(\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})\), where \(x\) is the input logits.
sm_scale (Optional[float]) – The scale used in softmax. If not provided, it is set to 1.0 / sqrt(head_dim).
rope_scale (Optional[float]) – The scale used in RoPE interpolation. If not provided, it is set to 1.0.
rope_theta (Optional[float]) – The theta used in RoPE. If not provided, it is set to 1e4.
return_lse (bool) – Whether to return the log sum exp value of the attention logits.
- Returns:
If return_lse is False, the attention output, shape: [qo_len, num_qo_heads, head_dim]. If return_lse is True, a tuple of two tensors:
The attention output, shape: [qo_len, num_qo_heads, head_dim].
The log sum exp value, shape: [qo_len, num_qo_heads].
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
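One common reason to request the log sum exp value is that it lets partial attention results computed over disjoint chunks of the KV sequence be combined exactly. The following is a minimal plain-Python sketch of that idea for a single query and a single head with scalar values. The helper names attention_row and merge_states are hypothetical and are not part of FlashInfer, and the sketch assumes a natural logarithm, since this page does not state which base the library uses.

```python
import math


def attention_row(scores, values):
    """Softmax-weighted average of scalar `values`, plus the log-sum-exp of `scores`."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    out = sum(e * v for e, v in zip(exps, values)) / denom
    lse = m + math.log(denom)  # natural log: an assumption of this sketch
    return out, lse


def merge_states(o_a, lse_a, o_b, lse_b):
    """Combine two partial results computed over disjoint KV chunks."""
    m = max(lse_a, lse_b)
    w_a = math.exp(lse_a - m)  # relative softmax mass of chunk A
    w_b = math.exp(lse_b - m)  # relative softmax mass of chunk B
    out = (w_a * o_a + w_b * o_b) / (w_a + w_b)
    lse = m + math.log(w_a + w_b)
    return out, lse


# Already-scaled attention logits and scalar "values" for one query position.
scores = [0.5, -1.0, 2.0, 0.0]
values = [1.0, 2.0, 3.0, 4.0]

full_out, full_lse = attention_row(scores, values)

# Attend to the two halves of the KV sequence separately, then merge.
o_a, lse_a = attention_row(scores[:2], values[:2])
o_b, lse_b = attention_row(scores[2:], values[2:])
merged_out, merged_lse = merge_states(o_a, lse_a, o_b, lse_b)
```

Up to floating-point rounding, merged_out and merged_lse match full_out and full_lse, which is why the output alone is not enough to merge chunks: the log sum exp carries each chunk's softmax normalizer.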
Examples
>>> import torch
>>> import flashinfer
>>> qo_len = 128
>>> kv_len = 4096
>>> num_qo_heads = 32
>>> num_kv_heads = 4
>>> head_dim = 128
>>> q = torch.randn(qo_len, num_qo_heads, head_dim).half().to("cuda:0")
>>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True,
...     allow_fp16_qk_reduction=True)
>>> o.shape
torch.Size([128, 32, 128])
>>> mask = torch.tril(
...     torch.full((qo_len, kv_len), True, device="cuda:0"),
...     diagonal=(kv_len - qo_len),
... )
>>> mask
tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]],
       device='cuda:0')
>>> o_custom = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask)
>>> torch.allclose(o, o_custom, rtol=1e-3, atol=1e-3)
True
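The custom mask in the example above is a lower-triangular matrix shifted by kv_len - qo_len, so the last query position can attend to every key. A small plain-Python sketch of the same rule, at sizes small enough to inspect by hand, may make the alignment clearer. The helper name causal_mask is hypothetical, and the sketch only mirrors the torch.tril call from the example; it does not describe the library's internals.

```python
def causal_mask(qo_len, kv_len):
    """Boolean mask equal to tril(ones(qo_len, kv_len), diagonal=kv_len - qo_len)."""
    offset = kv_len - qo_len  # queries are aligned to the END of the KV sequence
    return [[j <= i + offset for j in range(kv_len)] for i in range(qo_len)]


mask = causal_mask(qo_len=3, kv_len=5)
# mask[0] -> [True, True, True, False, False]
# mask[1] -> [True, True, True, True, False]
# mask[2] -> [True, True, True, True, True]
```

True means the key is visible to the query, matching the custom_mask convention described under Parameters.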
Note
The num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function will use grouped query attention.
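As an illustration of the note above, the sketch below shows how query heads are conventionally assigned to KV heads under grouped query attention. The helper name kv_head_for is hypothetical, and the contiguous grouping it uses (consecutive query heads share one KV head) is the usual convention, assumed here rather than confirmed by this page.

```python
def kv_head_for(qo_head, num_qo_heads, num_kv_heads):
    """Index of the KV head that a given query head reads from (assumed contiguous grouping)."""
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    group_size = num_qo_heads // num_kv_heads  # query heads per KV head
    return qo_head // group_size


# With the head counts from the example: 32 query heads, 4 KV heads, groups of 8.
mapping = [kv_head_for(h, num_qo_heads=32, num_kv_heads=4) for h in range(32)]
# Query heads 0..7 map to KV head 0, heads 8..15 to KV head 1, and so on.
```

When num_qo_heads equals num_kv_heads the group size is 1 and the mapping reduces to ordinary multi-head attention.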