flashinfer.fp4_quantization.nvfp4_kv_dequantize
- flashinfer.fp4_quantization.nvfp4_kv_dequantize(fp4_data: Tensor, block_scales: Tensor, global_scale: Tensor, output_dtype: dtype = torch.bfloat16) → Tensor
GPU dequantization of NVFP4 KV cache data with linear block scale layout.
Requires SM80+.
- Parameters:
  - fp4_data (torch.Tensor) – Packed FP4 data of shape [M, K/2] with dtype uint8.
  - block_scales (torch.Tensor) – Per-block FP8 E4M3 scales of shape [M, K/16] with dtype uint8.
  - global_scale (torch.Tensor) – Global scale factor of shape [1] with dtype float32, on the same CUDA device as fp4_data.
  - output_dtype (torch.dtype) – Output dtype, either torch.bfloat16 or torch.float16.
- Returns:
  Dequantized tensor of shape [M, K] with the specified output dtype.
- Return type:
  torch.Tensor