flashinfer.decode.single_decode_with_kv_cache_with_jit_module
- flashinfer.decode.single_decode_with_kv_cache_with_jit_module(jit_module: Any, q: Tensor, k: Tensor, v: Tensor, *args, kv_layout: str = 'NHD', window_left: int = -1, return_lse: bool = False)
Single-request decode using a pre-compiled JIT module.
This is the low-level entry point used by single_decode_with_kv_cache() after backend dispatch; user code should normally call single_decode_with_kv_cache() directly. This function is exposed for advanced users who already hold a compiled JIT module (for example, to customize the attention variant via extra kernel arguments).
- Parameters:
jit_module (Any) – Compiled JIT module returned by one of the gen_*_module factories.
q (torch.Tensor) – Query tensor, shape [num_qo_heads, head_dim].
k (torch.Tensor) – Key tensor, shape [kv_len, num_kv_heads, head_dim] (NHD) or [num_kv_heads, kv_len, head_dim] (HND).
v (torch.Tensor) – Value tensor, layout matches k.
*args – Extra positional arguments forwarded to the JIT module’s run symbol (e.g. soft-cap or sliding-window parameters required by the chosen variant).
kv_layout (str) – Layout of k and v, either "NHD" or "HND". Defaults to "NHD".
window_left (int) – Left window size for sliding-window attention; -1 disables it.
return_lse (bool) – If True, allocate an LSE buffer (shape (num_qo_heads,), float32) for the kernel to write into. Defaults to False. Note: the buffer is allocated and filled by the kernel but is not currently returned to the caller; this function always returns only o. Callers who need the LSE should use single_decode_with_kv_cache() instead.
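The two kv_layout values differ only in the order of the first two axes of k and v. A minimal sketch of the shapes described above, using a hypothetical helper named expected_kv_shape (it is not part of flashinfer, and the sizes are arbitrary):

```python
def expected_kv_shape(kv_layout, kv_len, num_kv_heads, head_dim):
    """Return the shape k and v are expected to have for a given layout.

    Hypothetical helper for illustration only; it is not part of flashinfer.
    """
    if kv_layout == "NHD":
        # Sequence axis first, then KV heads, then head dimension.
        return (kv_len, num_kv_heads, head_dim)
    if kv_layout == "HND":
        # KV heads first, then sequence axis, then head dimension.
        return (num_kv_heads, kv_len, head_dim)
    raise ValueError(f"kv_layout must be 'NHD' or 'HND', got {kv_layout!r}")


# Arbitrary example sizes: 1024 cached tokens, 8 KV heads, head_dim 128.
nhd_shape = expected_kv_shape("NHD", kv_len=1024, num_kv_heads=8, head_dim=128)
hnd_shape = expected_kv_shape("HND", kv_len=1024, num_kv_heads=8, head_dim=128)
```

A check of this kind before calling the low-level entry point can help catch a k or v tensor that was built for the other layout.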
- Returns:
Output tensor with shape matching q. The LSE buffer (when return_lse=True) is discarded; see the return_lse parameter note.
- Return type:
torch.Tensor
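To make the shapes of the output and the LSE buffer concrete, the following is a pure-Python reference sketch of single-request decode attention in the NHD layout. It is not the flashinfer kernel. The function name reference_single_decode is hypothetical, and several details are assumptions rather than statements from this page: scores are scaled by 1/sqrt(head_dim), the LSE is a natural-log logsumexp of the scaled scores, query heads map to KV heads in contiguous groups, and a non-negative window_left keeps the last position plus the window_left positions before it. The compiled variant may use different conventions.

```python
import math


def reference_single_decode(q, k, v, window_left=-1):
    """Reference decode attention on nested lists (NHD layout).

    q: [num_qo_heads][head_dim]; k, v: [kv_len][num_kv_heads][head_dim].
    Returns (out, lse) with out shaped like q and lse of length num_qo_heads.
    Scaling, LSE base, and window semantics are assumptions (see lead-in).
    """
    num_qo_heads, head_dim = len(q), len(q[0])
    kv_len, num_kv_heads = len(k), len(k[0])
    group_size = num_qo_heads // num_kv_heads  # grouped-query attention
    scale = 1.0 / math.sqrt(head_dim)
    # The decoded token is treated as sitting at position kv_len - 1.
    start = 0 if window_left < 0 else max(0, kv_len - 1 - window_left)

    out, lse = [], []
    for h in range(num_qo_heads):
        kv_h = h // group_size
        scores = [
            scale * sum(qi * ki for qi, ki in zip(q[h], k[t][kv_h]))
            for t in range(start, kv_len)
        ]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        denom = sum(weights)
        out.append([
            sum(w * v[start + i][kv_h][d] for i, w in enumerate(weights)) / denom
            for d in range(head_dim)
        ])
        lse.append(m + math.log(denom))
    return out, lse


# Two query heads sharing one KV head, two cached tokens, head_dim 2.
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[[0.0, 0.0]], [[0.0, 0.0]]]  # zero keys give uniform attention weights
v = [[[1.0, 2.0]], [[3.0, 4.0]]]
o_full, lse_full = reference_single_decode(q, k, v)
o_win, lse_win = reference_single_decode(q, k, v, window_left=0)
```

In this sketch o_full has the same shape as q and lse_full has one entry per query head, which mirrors the shapes documented above. Because this function returns only o, the LSE shown here corresponds to the buffer that is discarded.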