flashinfer.quantization.nvfp4_quantize_paged_kv_cache
- flashinfer.quantization.nvfp4_quantize_paged_kv_cache(k_cache: Tensor, v_cache: Tensor, kv_layout: str = 'HND', k_global_sf: Tensor | None = None, v_global_sf: Tensor | None = None) → Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], float, float]
Quantize a paged KV cache to NVFP4 for the trtllm-gen MHA kernel.
Quantizes BF16/FP16 K/V caches to NVFP4 with two-level scaling (global FP32 + per-block FP8) and swizzles the scale factors for the SM100 trtllm-gen MHA kernel layout.
- Parameters:
k_cache (torch.Tensor) – Key cache tensor. HND layout: [num_pages, num_kv_heads, page_size, head_dim]; NHD layout: [num_pages, page_size, num_kv_heads, head_dim].
v_cache (torch.Tensor) – Value cache tensor (same layout as k_cache).
kv_layout (str) – Layout of the input KV cache, either "HND" or "NHD".
k_global_sf (torch.Tensor, optional) – Global scale factor for K (float32 scalar tensor). When None, auto-computed as FLOAT8_E4M3_MAX / k_amax.
v_global_sf (torch.Tensor, optional) – Global scale factor for V (float32 scalar tensor). When None, auto-computed as FLOAT8_E4M3_MAX / v_amax.
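The default global scale factors can be sketched in plain Python. This is an illustration only, not FlashInfer's implementation: the helper name default_global_sf is hypothetical, amax is assumed to be the maximum absolute value over the whole cache, and 448.0 is the largest finite value representable in float8_e4m3fn. How the library treats an all-zero cache is not stated here, so the sketch simply raises in that case.

```python
# Illustrative sketch only; not FlashInfer's implementation.
FLOAT8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def default_global_sf(values):
    """Hypothetical helper: global scale factor as FLOAT8_E4M3_MAX / amax."""
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        # Assumption: an all-zero input has no meaningful scale in this sketch.
        raise ValueError("amax is zero; the global scale factor is undefined")
    return FLOAT8_E4M3_MAX / amax


k_values = [0.5, -3.5, 1.25]               # stand-in for the entries of k_cache
k_global_sf = default_global_sf(k_values)  # 448.0 / 3.5
k_global_scale = 1.0 / k_global_sf         # reciprocal, as in the trailing return floats
```

The reciprocal on the last line corresponds to the trailing floats the function returns (1 / k_global_sf and 1 / v_global_sf).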
- Returns:
(kv_cache_fp4, kv_cache_sf, k_global_scale, v_global_scale), where kv_cache_fp4 is (k_fp4, v_fp4) in the same layout as the input with head_dim replaced by head_dim // 2, dtype uint8; kv_cache_sf is (k_scales, v_scales) (k_scales keeps the linear input layout, v_scales uses TRT-LLM’s 4-token interleaved layout, both with head_dim replaced by head_dim // 16, dtype float8_e4m3fn); and the two trailing floats are 1 / k_global_sf and 1 / v_global_sf.
- Return type:
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], float, float]
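As a quick sanity check on the return description, the output shapes can be derived from the input shape without running the kernel. The sketch below is an assumption-laden illustration: expected_output_shapes is a hypothetical helper, not part of FlashInfer, and the requirement that head_dim be a multiple of 16 is inferred from the head_dim // 16 scale-factor dimension rather than stated by the library. Because head_dim is the last dimension in both the HND and NHD layouts, the same rule applies to either. The 4-token interleaving of v_scales changes element order, which this shape-only sketch does not model.

```python
# Illustrative sketch only; shapes follow the Returns description above.
def expected_output_shapes(cache_shape):
    """Hypothetical helper: (fp4_shape, sf_shape) for a 4-D paged KV cache shape."""
    if len(cache_shape) != 4:
        raise ValueError("expected a 4-D paged KV cache shape")
    *leading, head_dim = cache_shape
    if head_dim % 16 != 0:
        # Assumption: one FP8 scale factor covers 16 elements along head_dim.
        raise ValueError("head_dim is assumed to be a multiple of 16")
    fp4_shape = (*leading, head_dim // 2)   # two 4-bit values packed per uint8
    sf_shape = (*leading, head_dim // 16)   # one float8_e4m3fn scale per 16 elements
    return fp4_shape, sf_shape


# HND example: [num_pages, num_kv_heads, page_size, head_dim]
fp4_shape, sf_shape = expected_output_shapes((8, 4, 16, 128))
```

For this example input, the packed caches would have shape (8, 4, 16, 64) and the scale-factor tensors shape (8, 4, 16, 8); comparing these against k_fp4.shape and k_scales.shape is a cheap check after calling the real function.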