flashinfer.prefill.single_prefill_with_kv_cache_with_jit_module

flashinfer.prefill.single_prefill_with_kv_cache_with_jit_module(jit_module: Any, q: Tensor, k: Tensor, v: Tensor, *args, kv_layout: str = 'NHD', mask_mode: int = 0, window_left: int = -1, return_lse: bool = False) → Tensor | Tuple[Tensor, Tensor]

Single-request prefill / append attention using a pre-compiled JIT module.

Low-level entry point used by single_prefill_with_kv_cache() after backend dispatch; user code should normally call single_prefill_with_kv_cache directly. Exposed for advanced users who already hold a compiled JIT module (for example to customize the attention variant via extra kernel arguments).

Parameters:
  • jit_module (Any) – Compiled JIT module returned by one of the gen_*_module factories.

  • q (torch.Tensor) – Query tensor, shape [qo_len, num_qo_heads, head_dim_qk].

  • k (torch.Tensor) – Key tensor, shape [kv_len, num_kv_heads, head_dim_qk] (NHD) or [num_kv_heads, kv_len, head_dim_qk] (HND).

  • v (torch.Tensor) – Value tensor, shape [kv_len, num_kv_heads, head_dim_vo] (NHD) or [num_kv_heads, kv_len, head_dim_vo] (HND). Its layout matches k, but head_dim_vo may differ from head_dim_qk.

  • *args – Extra positional arguments forwarded to the JIT module’s run symbol (e.g. soft-cap, sink, or custom-mask buffers required by the chosen variant).

  • kv_layout (str) – Layout of k and v, either "NHD" or "HND". Defaults to "NHD".

  • mask_mode (int) – Mask mode, one of the values defined by MaskMode. Defaults to MaskMode.NON_CAUSAL.value.

  • window_left (int) – Left window size for sliding-window attention. Defaults to -1, which disables the window.

  • return_lse (bool) – Whether to allocate and return the log-sum-exp tensor. Defaults to False.
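The pure-Python predicate below shows how mask_mode and window_left combine to decide which keys each query row may see. It is a sketch of the semantics, not the kernel's code, and it rests on several assumptions:

- mask_mode 0 is non-causal and 1 is causal.
- When kv_len > qo_len, as in append attention, the causal diagonal is aligned to the bottom-right corner, so the last query sees every key.
- window_left counts how many keys to the left of the diagonal stay visible.

The helpers `is_visible` and `mask_matrix` are hypothetical.

```python
NON_CAUSAL = 0  # assumed to match MaskMode.NON_CAUSAL.value
CAUSAL = 1      # assumed to match MaskMode.CAUSAL.value


def is_visible(qo_idx, kv_idx, qo_len, kv_len, mask_mode=NON_CAUSAL, window_left=-1):
    """Return True if query row qo_idx may attend to key column kv_idx (hypothetical helper)."""
    # Assumed bottom-right alignment: query row qo_idx sits at key position qo_idx + (kv_len - qo_len).
    q_pos = qo_idx + (kv_len - qo_len)
    if mask_mode == CAUSAL and kv_idx > q_pos:
        return False  # key lies in the future of this query
    if window_left >= 0 and kv_idx < q_pos - window_left:
        return False  # key lies outside the left sliding window
    return True


def mask_matrix(qo_len, kv_len, mask_mode=NON_CAUSAL, window_left=-1):
    """Build a qo_len x kv_len grid of 0/1 visibility flags."""
    return [
        [int(is_visible(i, j, qo_len, kv_len, mask_mode, window_left)) for j in range(kv_len)]
        for i in range(qo_len)
    ]


# Two appended queries against five keys, causal, window of 2 keys left of the diagonal.
for row in mask_matrix(2, 5, CAUSAL, window_left=2):
    print(row)
# [0, 1, 1, 1, 0]
# [0, 0, 1, 1, 1]
```

Under these assumptions, each query sees at most window_left + 1 keys: the keys inside the window plus its own diagonal position.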

Returns:

If return_lse is False, the attention output tensor, of shape [qo_len, num_qo_heads, head_dim_vo]. Otherwise, the (output, lse) pair, where lse has shape [qo_len, num_qo_heads] and dtype float32.
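The lse output matters when one query block is attended against several KV chunks separately. Two partial outputs can then be recombined exactly by weighting each with the softmax mass recorded in its LSE. The pure-Python sketch below shows this for a single head and a single query row, using plain lists instead of tensors. The helpers `attend` and `merge` are hypothetical, not flashinfer APIs.

The sketch assumes a natural-log LSE over logits scaled by sm_scale. If a kernel reports LSE in another base (some attention kernels work in base 2 internally), the exp/log calls would need to change to match.

```python
import math


def attend(q, keys, values, sm_scale):
    """Reference single-head attention for one query row; returns (output, natural-log LSE)."""
    logits = [sm_scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(logits)  # subtract the max logit for numerical stability
    weights = [math.exp(s - m) for s in logits]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(len(values[0]))]
    return out, m + math.log(total)


def merge(out_a, lse_a, out_b, lse_b):
    """Combine two partial results computed over disjoint key sets (hypothetical helper)."""
    m = max(lse_a, lse_b)
    wa, wb = math.exp(lse_a - m), math.exp(lse_b - m)  # relative softmax mass of each part
    out = [(wa * a + wb * b) / (wa + wb) for a, b in zip(out_a, out_b)]
    return out, m + math.log(wa + wb)


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
scale = 1.0 / math.sqrt(len(q))

full_out, full_lse = attend(q, keys, values, scale)
part_a = attend(q, keys[:2], values[:2], scale)  # first KV chunk
part_b = attend(q, keys[2:], values[2:], scale)  # second KV chunk
merged_out, merged_lse = merge(*part_a, *part_b)

# Merging the chunk results reproduces attention over the full key set.
assert math.isclose(merged_lse, full_lse)
assert all(math.isclose(x, y) for x, y in zip(merged_out, full_out))
```

Because each partial output is normalized by its own softmax denominator, the LSE is exactly the information needed to undo that normalization when combining chunks.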

Return type:

Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]