flashinfer.quantization.nvfp4_quantize_paged_kv_cache

flashinfer.quantization.nvfp4_quantize_paged_kv_cache(k_cache: Tensor, v_cache: Tensor, kv_layout: str = 'HND', k_global_sf: Tensor | None = None, v_global_sf: Tensor | None = None) → Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], float, float]

Quantize a paged KV cache to NVFP4 for the trtllm-gen MHA kernel.

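Quantizes BF16/FP16 K/V caches to NVFP4 using two-level scaling: one global FP32 scale per tensor (K and V each have their own) plus one FP8 (E4M3) scale per block of 16 elements along head_dim. The scale factors are then swizzled into the layout expected by the SM100 trtllm-gen MHA kernel.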
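The two-level scheme can be illustrated with a small pure-Python sketch for a single block. This is not FlashInfer's kernel: the 16-element block size is inferred from the head_dim // 16 scale count, while the round-to-nearest rule and the `block_amax / 6` per-block scale are assumptions about how NVFP4 blocks are typically formed. The block scale stays a Python float instead of being rounded to float8_e4m3fn, and the helper names `quantize_block` and `dequantize_block` are hypothetical.

```python
# Hypothetical sketch of NVFP4 two-level scaling for one 16-element block.
# Not FlashInfer code: rounding and FP8 scale encoding are simplified.

# Non-negative magnitudes representable in FP4 E2M1.
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
FP4_MAX = 6.0
BLOCK = 16  # assumed block size, inferred from the head_dim // 16 scale count


def quantize_block(block, global_sf):
    """Return (block_scale, codes); codes are signed E2M1 values."""
    if len(block) != BLOCK:
        raise ValueError(f"expected {BLOCK} values, got {len(block)}")
    # Level 1: apply the per-tensor global scale.
    scaled = [x * global_sf for x in block]
    # Level 2: a per-block scale maps the block's max magnitude onto FP4_MAX.
    # The real format stores this scale as float8_e4m3fn; here it stays a float.
    amax = max(abs(x) for x in scaled)
    block_scale = amax / FP4_MAX if amax > 0 else 1.0
    codes = []
    for x in scaled:
        mag = abs(x) / block_scale
        # Round to the nearest representable E2M1 magnitude.
        nearest = min(E2M1_GRID, key=lambda g: abs(g - mag))
        codes.append(-nearest if x < 0 else nearest)
    return block_scale, codes


def dequantize_block(block_scale, codes, global_sf):
    """Undo both scaling levels to approximate the original values."""
    return [c * block_scale / global_sf for c in codes]


block = [0.1 * i for i in range(BLOCK)]
scale, codes = quantize_block(block, global_sf=2.0)
approx = dequantize_block(scale, codes, global_sf=2.0)
```

In this sketch, reconstruction divides by `global_sf`, that is, it multiplies by `1 / global_sf`, which corresponds to the reciprocal floats this function returns alongside the packed data.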

Parameters:
  • k_cache (torch.Tensor) – Key cache tensor. HND layout: [num_pages, num_kv_heads, page_size, head_dim]; NHD layout: [num_pages, page_size, num_kv_heads, head_dim].

  • v_cache (torch.Tensor) – Value cache tensor (same layout as k_cache).

  • kv_layout (str) – Layout of the input KV cache, either "HND" or "NHD".

  • k_global_sf (torch.Tensor, optional) – Global scale factor for K (float32 scalar tensor). When None, auto-computed as FLOAT8_E4M3_MAX / k_amax.

  • v_global_sf (torch.Tensor, optional) – Global scale factor for V (float32 scalar tensor). When None, auto-computed as FLOAT8_E4M3_MAX / v_amax.
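The documented default for k_global_sf and v_global_sf can be written out directly. This pure-Python sketch works on a flat list rather than a tensor; the helper name `default_global_sf` is hypothetical, and raising an error when amax is zero is an assumption (how nvfp4_quantize_paged_kv_cache handles an all-zero cache is not documented here). 448.0 is the largest finite float8_e4m3fn value.

```python
FLOAT8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def default_global_sf(values):
    """Documented default global scale: FLOAT8_E4M3_MAX / amax (hypothetical helper)."""
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        # Assumption: treat an all-zero input as an error instead of dividing by zero.
        raise ValueError("amax is zero; pass an explicit global scale instead")
    return FLOAT8_E4M3_MAX / amax


k_values = [0.25, -3.5, 1.75]
k_global_sf = default_global_sf(k_values)  # 448 / 3.5 = 128.0
k_global_scale = 1.0 / k_global_sf         # the float returned for K: 0.0078125
```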

Returns:

A 4-tuple (kv_cache_fp4, kv_cache_sf, k_global_scale, v_global_scale):

  • kv_cache_fp4 – (k_fp4, v_fp4), dtype uint8, in the same layout as the input with head_dim replaced by head_dim // 2 (two FP4 values per byte).

  • kv_cache_sf – (k_scales, v_scales), dtype float8_e4m3fn, with head_dim replaced by head_dim // 16. k_scales keeps the linear input layout; v_scales uses TRT-LLM’s 4-token interleaved layout.

  • k_global_scale, v_global_scale – Python floats equal to 1 / k_global_sf and 1 / v_global_sf.
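The shape contract for the returned tensors can be checked with plain arithmetic. In both HND and NHD layouts head_dim is the last dimension, so only that dimension changes. The helper `nvfp4_output_shapes` is hypothetical; requiring head_dim to be a multiple of 16 is an assumption inferred from the head_dim // 16 scale count, and the byte counts ignore the two scalar global scales.

```python
def nvfp4_output_shapes(shape):
    """Return (data_shape, scale_shape) for one quantized K or V cache.

    `shape` is the 4-D input shape in HND or NHD layout; head_dim is last in both.
    """
    if len(shape) != 4:
        raise ValueError("expected a 4-D paged KV cache shape")
    *lead, head_dim = shape
    if head_dim % 16 != 0:
        # Assumption: one FP8 scale per 16 elements requires head_dim % 16 == 0.
        raise ValueError("head_dim must be a multiple of 16")
    data_shape = (*lead, head_dim // 2)    # uint8, two FP4 codes per byte
    scale_shape = (*lead, head_dim // 16)  # float8_e4m3fn, one byte per scale
    return data_shape, scale_shape


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    n = 1
    for d in shape:
        n *= d
    return n


# HND cache: 8 pages, 4 KV heads, page_size 16, head_dim 128.
hnd_shape = (8, 4, 16, 128)
data_shape, scale_shape = nvfp4_output_shapes(hnd_shape)
bf16_bytes = numel(hnd_shape) * 2                     # 131072
nvfp4_bytes = numel(data_shape) + numel(scale_shape)  # 32768 + 4096 = 36864
```

Under this arithmetic each element costs 0.5 + 1/16 bytes instead of 2, so one K or V cache shrinks by about 3.56x relative to BF16.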

Return type:

Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], float, float]