flashinfer.prefill.single_prefill_with_kv_cache_with_jit_module
- flashinfer.prefill.single_prefill_with_kv_cache_with_jit_module(jit_module: Any, q: Tensor, k: Tensor, v: Tensor, *args, kv_layout: str = 'NHD', mask_mode: int = 0, window_left: int = -1, return_lse: bool = False) → Tensor | Tuple[Tensor, Tensor]
Single-request prefill / append attention using a pre-compiled JIT module.
Low-level entry point used by single_prefill_with_kv_cache() after backend dispatch. User code should normally call single_prefill_with_kv_cache() directly. This function is exposed for advanced users who already hold a compiled JIT module, for example to customize the attention variant through extra kernel arguments.
- Parameters:
  - jit_module (Any) – Compiled JIT module returned by one of the gen_*_module factories.
  - q (torch.Tensor) – Query tensor, shape [qo_len, num_qo_heads, head_dim_qk].
  - k (torch.Tensor) – Key tensor, shape [kv_len, num_kv_heads, head_dim_qk] (NHD) or [num_kv_heads, kv_len, head_dim_qk] (HND).
  - v (torch.Tensor) – Value tensor. Its layout matches k, but its last dimension may differ when head_dim_vo != head_dim_qk.
  - *args – Extra positional arguments forwarded to the JIT module's run symbol (e.g. soft-cap, sink, or custom-mask buffers required by the chosen variant).
  - kv_layout (str) – Layout of k and v, either "NHD" or "HND". Defaults to "NHD".
  - mask_mode (int) – Mask mode, one of the values defined by MaskMode. Defaults to MaskMode.NON_CAUSAL.value.
  - window_left (int) – Left window size for sliding-window attention; -1 disables it.
  - return_lse (bool) – Whether to allocate and return the log-sum-exp tensor. Defaults to False.
- Returns:
  If return_lse is False, the attention output tensor of shape [qo_len, num_qo_heads, head_dim_vo]. Otherwise the (output, lse) pair, where lse has shape [qo_len, num_qo_heads] and dtype float32.
- Return type:
  Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
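The following naive pure-Python reference makes concrete what kv_layout, a causal mask_mode, window_left and the returned lse mean. It is a minimal sketch for checking small inputs and rests on several assumptions, none of them confirmed by this page:

- Causal masking is bottom-right aligned, as in append attention. Query i sits at KV position i + (kv_len - qo_len).
- sm_scale defaults to 1/sqrt(head_dim_qk).
- lse is the natural log of the sum of exponentiated scaled scores. The kernel may use a different base internally.
- Query head h reads KV head h // (num_qo_heads // num_kv_heads).

reference_prefill and to_nhd are hypothetical helpers, not FlashInfer APIs. The boolean causal stands in for the MaskMode integer.

```python
import math


def to_nhd(x, kv_layout):
    """Return a nested-list K/V tensor in NHD order: [kv_len][num_kv_heads][dim]."""
    if kv_layout == "NHD":
        return x
    if kv_layout != "HND":
        raise ValueError(f"unknown kv_layout: {kv_layout!r}")
    # HND is [num_kv_heads][kv_len][dim]; swap the first two axes.
    return [[x[h][n] for h in range(len(x))] for n in range(len(x[0]))]


def reference_prefill(q, k, v, kv_layout="NHD", causal=False, window_left=-1, sm_scale=None):
    """Hypothetical naive attention for one request; returns (out, lse) as nested lists.

    q: [qo_len][num_qo_heads][head_dim_qk]; k and v are given in kv_layout.
    Assumption: causal masking is bottom-right aligned (query i sits at KV
    position i + kv_len - qo_len). Assumes every query sees at least one key.
    """
    k, v = to_nhd(k, kv_layout), to_nhd(v, kv_layout)
    qo_len, num_qo_heads, head_dim_qk = len(q), len(q[0]), len(q[0][0])
    kv_len, num_kv_heads = len(k), len(k[0])
    group = num_qo_heads // num_kv_heads  # query heads sharing one KV head (GQA)
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim_qk)  # assumed default scale
    out, lse = [], []
    for i in range(qo_len):
        q_pos = i + kv_len - qo_len
        # Range of KV indices this query may attend to.
        hi = min(q_pos, kv_len - 1) if causal else kv_len - 1
        lo = max(0, q_pos - window_left) if window_left >= 0 else 0
        visible = range(lo, hi + 1)
        out_row, lse_row = [], []
        for h in range(num_qo_heads):
            kh = h // group
            scores = {j: sm_scale * sum(a * b for a, b in zip(q[i][h], k[j][kh])) for j in visible}
            m = max(scores.values())  # subtract the max for numerical stability
            weights = {j: math.exp(s - m) for j, s in scores.items()}
            denom = sum(weights.values())
            lse_row.append(m + math.log(denom))  # natural-log LSE (assumed convention)
            head_dim_vo = len(v[0][kh])  # may differ from head_dim_qk
            out_row.append(
                [sum(w * v[j][kh][d] for j, w in weights.items()) / denom for d in range(head_dim_vo)]
            )
        out.append(out_row)
        lse.append(lse_row)
    return out, lse
```

Take qo_len = 2, kv_len = 3 and causal=True. Under the bottom-right convention, the first query attends to KV positions 0 and 1 and the second to positions 0, 1 and 2. Setting window_left=0 narrows each query to its own position only.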