flashinfer.prefill.single_prefill_with_kv_cache

flashinfer.prefill.single_prefill_with_kv_cache(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, custom_mask: torch.Tensor | None = None, packed_custom_mask: torch.Tensor | None = None, causal: bool = False, kv_layout: str = 'NHD', pos_encoding_mode: str = 'NONE', allow_fp16_qk_reduction: bool = False, sm_scale: float | None = None, logits_soft_cap: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None)

Prefill/append attention with KV cache for a single request; returns the attention output.

Parameters:
  • q (torch.Tensor) – The query tensor, shape: [qo_len, num_qo_heads, head_dim].

  • k (torch.Tensor) – The key tensor, shape: [kv_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, kv_len, head_dim] if kv_layout is HND.

  • v (torch.Tensor) – The value tensor, shape: [kv_len, num_kv_heads, head_dim] if kv_layout is NHD, or [num_kv_heads, kv_len, head_dim] if kv_layout is HND.

  • custom_mask (Optional[torch.Tensor]) –

    The custom boolean mask tensor, shape: [qo_len, kv_len]. A False element masks out the corresponding entry of the attention matrix; a True element keeps it.

    When custom_mask is provided and packed_custom_mask is not, the function packs custom_mask into a 1D packed mask tensor, which introduces additional overhead.

  • packed_custom_mask (Optional[torch.Tensor]) – The 1D packed uint8 mask tensor. If provided, custom_mask is ignored. The packed mask tensor can be generated by flashinfer.quantization.packbits().

  • causal (bool) – Whether to apply a causal mask to the attention matrix. This only takes effect when custom_mask is not provided. Default is False.

  • kv_layout (str) – The layout of the input k/v tensors, either NHD or HND. Default is NHD.

  • pos_encoding_mode (str) – The position encoding applied inside the attention kernels, one of NONE, ROPE_LLAMA (LLaMA-style rotary embedding), or ALIBI. Default is NONE.

  • allow_fp16_qk_reduction (bool) – Whether to use fp16 for the QK reduction (faster, at the cost of a slight loss of precision). Default is False.

  • logits_soft_cap (Optional[float]) – The attention logits soft-capping value (used in Gemini, Grok, Gemma-2, etc.). Defaults to 0 if not provided. If greater than 0, the logits are capped according to the formula \(\texttt{logits\_soft\_cap} \times \tanh(x / \texttt{logits\_soft\_cap})\), where \(x\) is the input logit.

  • sm_scale (Optional[float]) – The scale used in softmax. Defaults to 1.0 / sqrt(head_dim) if not provided.

  • rope_scale (Optional[float]) – The scale used in RoPE interpolation. Defaults to 1.0 if not provided.

  • rope_theta (Optional[float]) – The theta used in RoPE. Defaults to 1e4 if not provided.
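custom_mask and packed_custom_mask carry the same information: packing stores eight mask entries per uint8 byte, so a [qo_len, kv_len] bool mask (one byte per element) shrinks to ceil(qo_len * kv_len / 8) bytes. The plain-Python sketch below flattens the mask row-major and packs it, assuming little-endian bit order within each byte (flat element 8*b + i goes to bit i of byte b). That bit order is an assumption, and pack_mask_little / unpack_mask_little are hypothetical helpers for illustration; in practice, use flashinfer.quantization.packbits().

```python
import math


def pack_mask_little(mask):
    """Pack a 2-D boolean mask (list of rows) into a list of uint8 values.

    Hypothetical helper: flattens row-major, then stores flat element
    8*b + i in bit i of byte b (little-endian bit order, an assumption).
    """
    flat = [bool(x) for row in mask for x in row]
    packed = []
    for start in range(0, len(flat), 8):
        byte = 0
        for i, bit in enumerate(flat[start:start + 8]):
            if bit:
                byte |= 1 << i
        packed.append(byte)
    return packed


def unpack_mask_little(packed, qo_len, kv_len):
    """Inverse of pack_mask_little: recover the [qo_len][kv_len] mask."""
    flat = [bool((packed[k // 8] >> (k % 8)) & 1) for k in range(qo_len * kv_len)]
    return [flat[r * kv_len:(r + 1) * kv_len] for r in range(qo_len)]


# A 3 x 5 bottom-right-aligned causal mask: query i sees keys j <= i + (kv_len - qo_len).
qo_len, kv_len = 3, 5
mask = [[j <= i + (kv_len - qo_len) for j in range(kv_len)] for i in range(qo_len)]
packed = pack_mask_little(mask)
print(packed)  # [231, 125]: 15 mask bits fit in 2 bytes
```

Round-tripping through unpack_mask_little recovers the original mask, which is a quick way to sanity-check whichever bit order your packer actually uses.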
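When pos_encoding_mode is ROPE_LLAMA, rope_theta and rope_scale parameterize the rotary embedding. The sketch below assumes the common LLaMA-style "rotate-half" formulation: dimension i is paired with dimension i + head_dim/2, and the pair is rotated by the angle (pos / rope_scale) * rope_theta ** (-2i / head_dim), so rope_scale > 1 stretches positions (linear interpolation). The exact pairing and position assignment inside FlashInfer's kernels are assumptions here, and rope_llama is a hypothetical helper.

```python
import math


def rope_llama(x, pos, rope_scale=1.0, rope_theta=1e4):
    """Apply rotate-half RoPE to one head vector x (even length head_dim).

    Hypothetical helper: pairs x[i] with x[i + head_dim // 2] and rotates the
    pair by angle (pos / rope_scale) * rope_theta ** (-2 * i / head_dim).
    """
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        freq = rope_theta ** (-2.0 * i / d)
        angle = (pos / rope_scale) * freq
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = a * s + b * c
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q = [0.1 * i for i in range(8)]
k = [1.0 - 0.05 * i for i in range(8)]
# The rotated q.k score depends only on the relative offset m - n.
s1 = dot(rope_llama(q, 5), rope_llama(k, 2))
s2 = dot(rope_llama(q, 10), rope_llama(k, 7))
print(abs(s1 - s2) < 1e-9)  # True
```

The relative-offset property is why RoPE can be applied inside the attention kernel: each query/key pair only needs its two positions, not a separate position-embedding tensor.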

Returns:

The attention output, shape: [qo_len, num_qo_heads, head_dim].

Return type:

torch.Tensor
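For one query head, the returned output is a softmax over masked, scaled logits applied to V. The plain-Python sketch below spells this out for a single (query head, KV head) pair, to check small cases by hand. It assumes pos_encoding_mode="NONE", logits equal to sm_scale * (q · k), a soft cap (when logits_soft_cap > 0) applied to those already-scaled logits, and masked-out entries set to -inf before the softmax; the helper name reference_attention is hypothetical and this is not FlashInfer's kernel.

```python
import math


def reference_attention(q_rows, k_rows, v_rows, mask=None, causal=False,
                        sm_scale=None, logits_soft_cap=None):
    """Single-head attention in plain Python (hypothetical reference helper).

    q_rows: qo_len vectors; k_rows, v_rows: kv_len vectors for one head.
    mask: optional [qo_len][kv_len] booleans; False masks an entry out.
    """
    qo_len, kv_len = len(q_rows), len(k_rows)
    head_dim = len(q_rows[0])
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)  # documented default
    out = []
    for i, q in enumerate(q_rows):
        logits = []
        for j, k in enumerate(k_rows):
            x = sm_scale * sum(a * b for a, b in zip(q, k))
            if logits_soft_cap is not None and logits_soft_cap > 0:
                # Assumption: the cap is applied to the already-scaled logit.
                x = logits_soft_cap * math.tanh(x / logits_soft_cap)
            if mask is not None:
                allowed = mask[i][j]  # custom mask takes precedence over causal
            elif causal:
                # Bottom-right aligned: queries are the last qo_len positions.
                allowed = j <= i + (kv_len - qo_len)
            else:
                allowed = True
            logits.append(x if allowed else -math.inf)
        m = max(logits)
        assert m > -math.inf, "every query row needs at least one unmasked key"
        weights = [math.exp(x - m) for x in logits]  # exp(-inf) == 0.0
        total = sum(weights)
        out.append([sum(w * v[d] for w, v in zip(weights, v_rows)) / total
                    for d in range(len(v_rows[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
mask = [[j <= i + 1 for j in range(3)] for i in range(2)]
# causal=True and the equivalent explicit mask give the same result.
print(reference_attention(q, k, v, causal=True) == reference_attention(q, k, v, mask=mask))  # True
```

Grouped-query attention reduces to this per-head computation once each query head is paired with its KV head.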

Examples

>>> import torch
>>> import flashinfer
>>> qo_len = 128
>>> kv_len = 4096
>>> num_qo_heads = 32
>>> num_kv_heads = 4
>>> head_dim = 128
>>> q = torch.randn(qo_len, num_qo_heads, head_dim).half().to("cuda:0")
>>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True,
...     allow_fp16_qk_reduction=True)
>>> o.shape
torch.Size([128, 32, 128])
>>> mask = torch.tril(
...     torch.full((qo_len, kv_len), True, device="cuda:0"),
...     diagonal=(kv_len - qo_len),
... )
>>> mask
tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]], device='cuda:0')
>>> o_custom = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask)
>>> torch.allclose(o, o_custom, rtol=1e-3, atol=1e-3)
True
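The example works because torch.tril(..., diagonal=kv_len - qo_len) keeps entry (i, j) exactly when j - i <= kv_len - qo_len: query i sits at absolute position kv_len - qo_len + i and may attend to every key up to and including that position, which is the causal alignment used when causal=True. The plain-Python check below mirrors the tril rule with a hypothetical tril_mask helper and counts the visible keys per query row for the example's sizes.

```python
def tril_mask(qo_len, kv_len, diagonal):
    # Mirrors torch.tril semantics: keep (i, j) when j - i <= diagonal.
    return [[j - i <= diagonal for j in range(kv_len)] for i in range(qo_len)]


qo_len, kv_len = 128, 4096
mask = tril_mask(qo_len, kv_len, diagonal=kv_len - qo_len)
visible = [sum(row) for row in mask]
# First query sees keys 0..3968; the last query sees all 4096 keys.
print(visible[0], visible[-1])  # 3969 4096
```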

Notes

num_qo_heads must be a multiple of num_kv_heads. If num_qo_heads is not equal to num_kv_heads, the function uses grouped-query attention (GQA), in which several query heads share each KV head.
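With grouped-query attention, each group of num_qo_heads // num_kv_heads query heads shares one KV head. The sketch below assumes the conventional contiguous grouping, where query head h reads KV head h // group_size; treat the exact mapping as an assumption, and kv_head_for as a hypothetical helper.

```python
def kv_head_for(qo_head, num_qo_heads, num_kv_heads):
    """Return the KV head a query head reads under contiguous GQA grouping."""
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    group_size = num_qo_heads // num_kv_heads
    return qo_head // group_size


# Sizes from the example: 32 query heads sharing 4 KV heads (group size 8).
mapping = [kv_head_for(h, 32, 4) for h in range(32)]
print(mapping[:9])  # [0, 0, 0, 0, 0, 0, 0, 0, 1]
```

Under this mapping, an NHD k of shape [kv_len, num_kv_heads, head_dim] expanded with k.repeat_interleave(group_size, dim=1) gives an equivalent multi-head layout, at the cost of materializing group_size copies of each KV head; GQA avoids that copy.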