flashinfer.fp4_quantization.e2m1_and_ufp8sf_scale_to_float
- flashinfer.fp4_quantization.e2m1_and_ufp8sf_scale_to_float(e2m1_tensor: torch.Tensor, ufp8_scale_tensor: torch.Tensor, global_scale_tensor: torch.Tensor | None = None, sf_vec_size: int = 16, ufp8_type: int = 1, is_sf_swizzled_layout: bool = True) → torch.Tensor
Convert a packed E2M1-format tensor and its UFP8 scale factors to a float tensor.
This function performs dequantization: it converts a packed FP4 tensor in E2M1 format back to float values using the associated UFP8 scale factors and an optional global scale.
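For reference, E2M1 is a 4-bit floating-point encoding (1 sign bit, 2 exponent bits, 1 mantissa bit, exponent bias 1) whose eight representable magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}. A minimal pure-Python decoder, independent of flashinfer and for illustration only (the low-nibble-first packing order is an assumption):

```python
# E2M1 magnitude table, indexed by the low 3 bits (exponent:2, mantissa:1).
# With exponent bias 1: e == 0 is subnormal (value = m * 0.5),
# otherwise value = (1 + m/2) * 2**(e - 1).
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code (bit 3 = sign) to a Python float."""
    sign = -1.0 if code & 0b1000 else 1.0
    return sign * E2M1_LUT[code & 0b0111]

def unpack_byte(b: int) -> tuple[float, float]:
    """Unpack one uint8 holding two packed FP4 values.
    Assumption: the low nibble holds the first (even-indexed) element."""
    return decode_e2m1(b & 0x0F), decode_e2m1(b >> 4)
```

This is why the packed input has shape [M, K/2]: each uint8 carries two FP4 elements.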
- Parameters:
e2m1_tensor (torch.Tensor) – Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
ufp8_scale_tensor (torch.Tensor) – Scale factors tensor in UFP8 format with dtype uint8.
global_scale_tensor (torch.Tensor, optional) – Global scale factor of shape [1] and dtype float32. Defaults to None.
sf_vec_size (int, optional) – Scale factor vector size. Defaults to 16.
ufp8_type (int, optional) – UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
is_sf_swizzled_layout (bool, optional) – Whether scale factors use swizzled layout. Defaults to True.
- Returns:
Dequantized float tensor of shape [M, K] with dtype float32.
- Return type:
torch.Tensor
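Putting the pieces together, the dequantization this function performs can be sketched in pure Python for a single row with the default E4M3 scales (ufp8_type=1) and a linear, non-swizzled scale layout. The E4M3 decoding (bias 7), the low-nibble-first packing, and applying the global scale as a plain multiply are assumptions for illustration; the actual kernel also handles the swizzled layout and UE8M0 scales.

```python
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code (bit 3 = sign)."""
    sign = -1.0 if code & 0b1000 else 1.0
    return sign * E2M1_LUT[code & 0b0111]

def decode_e4m3(byte: int) -> float:
    """Decode a uint8 holding an FP8 E4M3 value (exponent bias 7)."""
    sign = -1.0 if byte & 0x80 else 1.0
    e = (byte >> 3) & 0x0F
    m = byte & 0x07
    if e == 0:  # subnormal
        return sign * (m / 8.0) * 2.0 ** -6
    return sign * (1.0 + m / 8.0) * 2.0 ** (e - 7)

def dequantize_row(packed, scales, global_scale=1.0, sf_vec_size=16):
    """Dequantize one row of K/2 packed bytes, applying one E4M3 scale
    per sf_vec_size consecutive elements (linear scale layout assumed)."""
    out = []
    for i, b in enumerate(packed):
        lo, hi = decode_e2m1(b & 0x0F), decode_e2m1(b >> 4)
        for j, v in enumerate((lo, hi)):
            k = 2 * i + j  # element index within the row
            sf = decode_e4m3(scales[k // sf_vec_size])
            out.append(v * sf * global_scale)
    return out
```

A row of K/2 bytes thus expands to K floats, matching the [M, K/2] → [M, K] shapes documented above.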