flashinfer.decode.single_decode_with_kv_cache_with_jit_module

flashinfer.decode.single_decode_with_kv_cache_with_jit_module(jit_module: Any, q: Tensor, k: Tensor, v: Tensor, *args, kv_layout: str = 'NHD', window_left: int = -1, return_lse: bool = False)

Single-request decode using a pre-compiled JIT module.

This is the low-level entry point that single_decode_with_kv_cache() calls after backend dispatch; user code should normally call single_decode_with_kv_cache() directly. It is exposed for advanced users who already hold a compiled JIT module, for example to customize the attention variant via extra kernel arguments.

Parameters:
  • jit_module (Any) – Compiled JIT module returned by one of the gen_*_module factories.

  • q (torch.Tensor) – Query tensor, shape [num_qo_heads, head_dim].

  • k (torch.Tensor) – Key tensor, shape [kv_len, num_kv_heads, head_dim] (NHD) or [num_kv_heads, kv_len, head_dim] (HND).

  • v (torch.Tensor) – Value tensor, in the same layout as k.

  • *args – Extra positional arguments forwarded to the JIT module’s run symbol (e.g. soft-cap or sliding-window parameters required by the chosen variant).

  • kv_layout (str) – Layout of k and v, either "NHD" or "HND". Defaults to "NHD".

  • window_left (int) – Left window size for sliding-window attention; -1 disables it. Defaults to -1.

  • return_lse (bool) – If True, allocate an LSE (log-sum-exp) buffer of shape [num_qo_heads] and dtype float32 for the kernel to write into. Defaults to False. Note: the kernel fills this buffer, but it is not currently returned to the caller; this function always returns only the output tensor o. Callers who need the LSE should use single_decode_with_kv_cache() instead.
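The shape conventions in the parameter list can be checked before a kernel is launched. The following is a minimal sketch, not part of flashinfer: check_decode_shapes is a hypothetical helper that works on plain shape tuples (for example tuple(q.shape)). It assumes that v has exactly the same shape as k and that num_qo_heads must be a multiple of num_kv_heads (grouped-query attention); neither assumption is stated on this page.

```python
from typing import Sequence, Tuple


def check_decode_shapes(
    q_shape: Sequence[int],
    k_shape: Sequence[int],
    v_shape: Sequence[int],
    kv_layout: str = "NHD",
) -> Tuple[int, int, int, int]:
    """Validate the documented shapes.

    Returns (num_qo_heads, num_kv_heads, kv_len, head_dim).
    Hypothetical helper for illustration; it is not a flashinfer API.
    """
    if kv_layout not in ("NHD", "HND"):
        raise ValueError(f"kv_layout must be 'NHD' or 'HND', got {kv_layout!r}")
    if len(q_shape) != 2 or len(k_shape) != 3:
        raise ValueError("q must be 2-D and k/v must be 3-D")
    # Assumption: v has exactly the same shape as k.
    if tuple(k_shape) != tuple(v_shape):
        raise ValueError("k and v must have the same shape")

    num_qo_heads, head_dim = q_shape
    if kv_layout == "NHD":
        # NHD: [kv_len, num_kv_heads, head_dim]
        kv_len, num_kv_heads, kv_head_dim = k_shape
    else:
        # HND: [num_kv_heads, kv_len, head_dim]
        num_kv_heads, kv_len, kv_head_dim = k_shape

    if kv_head_dim != head_dim:
        raise ValueError("head_dim of q and k/v differ")
    # Assumption: grouped-query attention needs an integer group size.
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    return num_qo_heads, num_kv_heads, kv_len, head_dim
```

Passing the wrong kv_layout for a given tensor usually swaps kv_len and num_kv_heads silently, so a check of this kind is mostly useful when the two values happen to be compatible with both layouts.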

Returns:

The attention output o, a tensor with the same shape as q ([num_qo_heads, head_dim]). The LSE buffer (when return_lse=True) is discarded; see the note on the return_lse parameter.

Return type:

torch.Tensor
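To make the semantics of the output and the LSE concrete, here is a dependency-free reference sketch that uses nested Python lists in place of tensors. It is not the library's kernel. reference_single_decode is a hypothetical name, and the sketch makes several assumptions that this page does not confirm: a softmax scale of 1/sqrt(head_dim), a natural-log LSE (the kernel may use a different base), query head h reading KV head h // (num_qo_heads // num_kv_heads), and window_left = w meaning that the query, placed at position kv_len - 1, attends to positions kv_len - 1 - w through kv_len - 1.

```python
import math
from typing import List, Tuple


def reference_single_decode(
    q: List[List[float]],
    k: List[List[List[float]]],
    v: List[List[List[float]]],
    kv_layout: str = "NHD",
    window_left: int = -1,
) -> Tuple[List[List[float]], List[float]]:
    """Pure-Python single-request decode attention (illustrative only)."""
    num_qo_heads, head_dim = len(q), len(q[0])
    if kv_layout == "NHD":
        kv_len, num_kv_heads = len(k), len(k[0])
    elif kv_layout == "HND":
        num_kv_heads, kv_len = len(k), len(k[0])
    else:
        raise ValueError("kv_layout must be 'NHD' or 'HND'")

    def at(cache, pos, head):
        # Fetch the head_dim vector for (position, head) in either layout.
        return cache[pos][head] if kv_layout == "NHD" else cache[head][pos]

    group = num_qo_heads // num_kv_heads   # assumed GQA head mapping
    sm_scale = 1.0 / math.sqrt(head_dim)   # assumed default softmax scale
    # Assumed window semantics: keep the last (window_left + 1) positions.
    start = 0 if window_left < 0 else max(0, kv_len - 1 - window_left)
    positions = range(start, kv_len)

    out, lse = [], []
    for h in range(num_qo_heads):
        kv_h = h // group
        scores = [
            sm_scale * sum(a * b for a, b in zip(q[h], at(k, pos, kv_h)))
            for pos in positions
        ]
        m = max(scores)                    # subtract max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        denom = sum(weights)
        lse.append(m + math.log(denom))    # natural-log LSE (assumption)
        row = [0.0] * head_dim
        for w, pos in zip(weights, positions):
            vec = at(v, pos, kv_h)
            for d in range(head_dim):
                row[d] += (w / denom) * vec[d]
        out.append(row)
    return out, lse
```

The first returned value corresponds to o and has the same [num_qo_heads, head_dim] shape as q; the second has one entry per query head, matching the [num_qo_heads] LSE buffer described above. The sketch does not handle an empty KV cache (kv_len = 0).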