flashinfer.fp4_quantization.nvfp4_kv_dequantize

flashinfer.fp4_quantization.nvfp4_kv_dequantize(fp4_data: Tensor, block_scales: Tensor, global_scale: Tensor, output_dtype: dtype = torch.bfloat16) → Tensor

GPU dequantization of NVFP4 KV cache data with linear block scale layout.

Requires SM80+.

Parameters:
  • fp4_data (torch.Tensor) – Packed FP4 data of shape [M, K/2] with dtype uint8; each byte packs two 4-bit E2M1 values.

  • block_scales (torch.Tensor) – Per-block FP8 E4M3 scales of shape [M, K/16], stored bitwise in a uint8 tensor; each scale covers 16 consecutive elements along K.

  • global_scale (torch.Tensor) – Global scale factor of shape [1] with dtype float32, on the same CUDA device as fp4_data.

  • output_dtype (torch.dtype) – Output dtype, either torch.bfloat16 or torch.float16.

Returns:

Dequantized tensor of shape [M, K] with the specified output dtype.

Return type:

torch.Tensor
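Example

Conceptually, each output element is the product of its decoded E2M1 value, its block's decoded E4M3 scale, and the global scale. The sketch below is a plain-Python CPU reference of that per-element math, not the flashinfer kernel: the E2M1 and E4M3 decode tables follow the standard encodings, but the nibble order (low nibble first) and the purely multiplicative role of `global_scale` are assumptions that may differ from the library's convention.

```python
# CPU reference sketch of NVFP4 dequantization math (illustration only).
# Assumes standard E2M1 codes for FP4 values and E4M3 for block scales;
# nibble order and the role of global_scale are assumptions.

# The eight non-negative E2M1 magnitudes, indexed by the low 3 bits of a code.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_e2m1(code: int) -> float:
    """Decode a 4-bit E2M1 code (bit 3 = sign) to a float."""
    mag = E2M1_MAGNITUDES[code & 0x7]
    return -mag if code & 0x8 else mag

def decode_e4m3(byte: int) -> float:
    """Decode an 8-bit E4M3 value (exponent bias 7) to a float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0:  # subnormal: 2^(1-7) * (man / 8) == man * 2^-9
        return sign * man * 2.0 ** -9
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

def dequantize_row(fp4_row, scale_row, global_scale):
    """Dequantize one row: each byte packs two FP4 codes (low nibble first),
    and each E4M3 scale covers a block of 16 output elements."""
    out = []
    for i, byte in enumerate(fp4_row):
        for code in (byte & 0xF, byte >> 4):
            block_scale = decode_e4m3(scale_row[(2 * i) // 16])
            out.append(decode_e2m1(code) * block_scale * global_scale)
    return out
```

For instance, a single byte `0x21` decodes to the codes 1 (low nibble, value 0.5) and 2 (high nibble, value 1.0); with a block scale byte of `0x38` (E4M3 for 1.0) and a global scale of 1.0, the row dequantizes to `[0.5, 1.0]`.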