flashinfer.gemm.gemm_fp8_nt_groupwise
- flashinfer.gemm.gemm_fp8_nt_groupwise(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), out: torch.Tensor | None = None, out_dtype: torch.dtype | None = None) → torch.Tensor
Performs matrix multiplication with FP8 data types using groupwise scaling.
This function implements a GEMM with fine-grained control over the scale granularity along the m, n, and k dimensions. It is currently supported only on the NVIDIA Blackwell architecture.
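To make "groupwise scaling" concrete, here is a minimal pure-Python reference sketch. It assumes the usual dequantize-then-multiply semantics: every FP8 element is multiplied by the scale of the block it belongs to, and then a standard A·Bᵀ product is taken. The name reference_groupwise_gemm is hypothetical and is not part of FlashInfer. The inputs are plain floats standing in for already-decoded FP8 values. The scales are indexed with the (k-group, m-group) and (k-group, n-group) shapes documented for a_scale and b_scale.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def reference_groupwise_gemm(
    a: Sequence[Sequence[float]],        # (m, k), values already decoded from FP8
    b: Sequence[Sequence[float]],        # (n, k), values already decoded from FP8
    a_scale: Sequence[Sequence[float]],  # (ceil_div(k, kg), ceil_div(m, mg))
    b_scale: Sequence[Sequence[float]],  # (ceil_div(k, kg), ceil_div(n, ng))
    granularity_mnk: Tuple[int, int, int] = (1, 128, 128),
) -> Matrix:
    """Hypothetical reference helper, NOT a FlashInfer API.

    Assumes dequantize-then-multiply semantics:
    out[i][j] = sum_p (a[i][p] * a_scale[p//kg][i//mg]) * (b[j][p] * b_scale[p//kg][j//ng])
    """
    mg, ng, kg = granularity_mnk
    m, n, k = len(a), len(b), len(a[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                sa = a_scale[p // kg][i // mg]  # scale of a's (m-group, k-group) block
                sb = b_scale[p // kg][j // ng]  # scale of b's (n-group, k-group) block
                acc += (a[i][p] * sa) * (b[j][p] * sb)
            out[i][j] = acc
    return out


# Tiny example: m = n = 2, k = 4, granularity (1, 2, 2).
a = [[1.0] * 4, [2.0] * 4]
b = [[1.0] * 4, [1.0, 0.0, 1.0, 0.0]]
a_scale = [[1.0, 0.5], [2.0, 1.0]]  # 2 k-groups x 2 m-groups
b_scale = [[1.0], [3.0]]             # 2 k-groups x 1 n-group
print(reference_groupwise_gemm(a, b, a_scale, b_scale, (1, 2, 2)))
# [[14.0, 7.0], [14.0, 7.0]]
```

With granularity (1, 2, 2), each row of a carries its own scale for every 2-wide k-block, while b shares one scale per 2×2 block. This mirrors the default (1, 128, 128) layout at a smaller size.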
- Parameters:
a (torch.Tensor) – Row-major input tensor of shape (m, k), with dtype fp8 e4m3 or fp8 e5m2.
b (torch.Tensor) – Column-major input tensor of shape (n, k), with dtype fp8 e4m3 or fp8 e5m2.
a_scale (torch.Tensor) – Column-major scale tensor for a, of shape (ceil_div(k, k_granularity), ceil_div(m, m_granularity)).
b_scale (torch.Tensor) – Row-major scale tensor for b, of shape (ceil_div(k, k_granularity), ceil_div(n, n_granularity)).
scale_granularity_mnk (Tuple[int, int, int]) – The granularity of the scale tensors, given as (m_granularity, n_granularity, k_granularity).
out (Optional[torch.Tensor]) – Output tensor of shape (m, n). If not specified, a new output tensor is allocated.
out_dtype (Optional[torch.dtype]) – Dtype of the output tensor that is created when out is not specified. Defaults to torch.bfloat16.
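 
The scale-tensor shapes above follow directly from ceil_div and the granularity tuple. The sketch below computes the expected logical shapes so they can be checked before the call. The helper name expected_scale_shapes is hypothetical, not a FlashInfer function. It checks only the logical shapes and does not verify the column-major or row-major memory layout.

```python
from typing import Tuple


def ceil_div(x: int, y: int) -> int:
    """Ceiling division for positive integers, e.g. ceil_div(300, 128) == 3."""
    return (x + y - 1) // y


def expected_scale_shapes(
    m: int,
    n: int,
    k: int,
    scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128),
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Hypothetical helper: (a_scale_shape, b_scale_shape) per the documented formulas."""
    m_gran, n_gran, k_gran = scale_granularity_mnk
    k_groups = ceil_div(k, k_gran)
    a_scale_shape = (k_groups, ceil_div(m, m_gran))  # one scale per row per k-block
    b_scale_shape = (k_groups, ceil_div(n, n_gran))  # one scale per n-block per k-block
    return a_scale_shape, b_scale_shape


# m=256, n=512, k=7168 with the default (1, 128, 128) granularity.
print(expected_scale_shapes(256, 512, 7168))  # ((56, 256), (56, 4))
```

Because of the ceiling division, a k that is not a multiple of k_granularity still gets a scale for its final partial block.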
- Returns:
out – Output tensor, shape (m, n).
- Return type:
torch.Tensor
Notes
If m is not a multiple of 4, it is padded to the next multiple of 4 to accommodate the kernel's requirement.
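The padding rule in the note is a round-up to a multiple of 4. The sketch below shows only that arithmetic. The helper pad_to_multiple is hypothetical, and the sketch does not claim to reproduce how FlashInfer allocates or slices the padded rows internally.

```python
def pad_to_multiple(x: int, multiple: int = 4) -> int:
    """Round x up to the next multiple of `multiple`; unchanged if already aligned."""
    return (x + multiple - 1) // multiple * multiple


for m in (1, 4, 5, 127, 128):
    print(m, "->", pad_to_multiple(m))
# 1 -> 4
# 4 -> 4
# 5 -> 8
# 127 -> 128
# 128 -> 128
```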