flashinfer.gemm.gemm_fp8_nt_groupwise
- flashinfer.gemm.gemm_fp8_nt_groupwise(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, scale_major_mode: Literal['MN', 'K'] = 'MN', mma_sm: int = 1, scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), out: torch.Tensor | None = None, out_dtype: torch.dtype | None = None) → torch.Tensor
Performs matrix multiplication with FP8 data types using groupwise scaling.
This function implements a GEMM operation that allows for fine-grained control over scale granularity across different dimensions. Currently only supported on NVIDIA Blackwell architecture.
- Parameters:
a (torch.Tensor) – Row-major input tensor, shape (m, k), fp8 e4m3 or fp8 e5m2.
b (torch.Tensor) – Column-major input tensor, shape (n, k), fp8 e4m3 or fp8 e5m2.
a_scale (torch.Tensor) – Column-major scale tensor for a, shape (m, k // block_size) if scale_major_mode is K, or shape (k // block_size, m) if scale_major_mode is MN.
b_scale (torch.Tensor) – Row-major scale tensor for b, shape (n // block_size, k // block_size) if scale_major_mode is K, or shape (k // block_size, n // block_size) if scale_major_mode is MN.
scale_granularity_mnk (Tuple[int, int, int]) – The granularity of the scale tensor, (m_granularity, n_granularity, k_granularity).
scale_major_mode (Literal["MN", "K"]) – The layout mode of the scale tensors: MN for MN-major scales, with shape (k // block_size, *), or K for K-major scales, with shape (*, k // block_size). See the shape sketch after this parameter list.
mma_sm (int) – How many SMs to use for the MMA operation; must be 1 or 2. Using 2 is faster when the number of rows (m) per group is large (>= 256).
out (Optional[torch.Tensor]) – Output tensor, shape (m, n). If not specified, a new output tensor is allocated.
out_dtype (Optional[torch.dtype]) – If out is not specified, the output tensor is created with this dtype. Defaults to torch.bfloat16.
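To make the two scale layouts concrete, here is a short sketch with hypothetical sizes (m=256, n=512, k=1024, and block_size=128 as implied by the default scale_granularity_mnk of (1, 128, 128)):

>>> m, n, k, block_size = 256, 512, 1024, 128
>>> # K-major mode: the k // block_size axis comes last
>>> (m, k // block_size), (n // block_size, k // block_size)  # a_scale, b_scale shapes
((256, 8), (4, 8))
>>> # MN-major mode: the k // block_size axis comes first
>>> (k // block_size, m), (k // block_size, n // block_size)  # a_scale, b_scale shapes
((8, 256), (8, 4))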
- Returns:
out – Output tensor, shape (m, n).
- Return type:
torch.Tensor
Notes
The m dimension should be padded to a multiple of 4 before calling this function, to accommodate the kernel's requirement.
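Examples

A minimal sketch of a call, assuming a Blackwell GPU is available. The all-ones scales are placeholders rather than a real groupwise quantization, and the float32 scale dtype is an assumption worth verifying against your flashinfer version:

>>> import torch
>>> import flashinfer
>>> m, n, k, block_size = 256, 4096, 4096, 128  # m is a multiple of 4 (see Notes)
>>> a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)  # row-major, shape (m, k)
>>> b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)  # shape (n, k), holds B transposed
>>> # MN-major scales: leading dimension is k // block_size
>>> a_scale = torch.ones(k // block_size, m, device="cuda", dtype=torch.float32)
>>> b_scale = torch.ones(k // block_size, n // block_size, device="cuda", dtype=torch.float32)
>>> out = flashinfer.gemm.gemm_fp8_nt_groupwise(
...     a, b, a_scale, b_scale,
...     scale_major_mode="MN",
...     scale_granularity_mnk=(1, 128, 128),
...     out_dtype=torch.bfloat16,
... )
>>> out.shape
torch.Size([256, 4096])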