flashinfer.gemm.group_gemm_fp8_nt_groupwise¶
- flashinfer.gemm.group_gemm_fp8_nt_groupwise(a: Tensor, b: Tensor, a_scale: Tensor, b_scale: Tensor, m_indptr: Tensor, scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), scale_major_mode: Literal['MN', 'K'] = 'MN', mma_sm: int = 1, out: Tensor | None = None, out_dtype: dtype | None = None) Tensor¶
Perform group GEMM with FP8 data types using groupwise scaling. Currently only supported on NVIDIA Blackwell architecture.
- Parameters:
  - a (torch.Tensor) – Row-major input tensor, shape (cum_m, k), data type torch.float8_e4m3fn or torch.float8_e5m2. cum_m is the cumulative sum of the segment lengths.
  - b (torch.Tensor) – Column-major input tensor, shape (batch_size, n, k), data type torch.float8_e4m3fn or torch.float8_e5m2.
  - a_scale (torch.Tensor) – Column-major scale tensor for a, shape (cum_m, k // block_size) if scale_major_mode is K, or shape (k // block_size, cum_m) if scale_major_mode is MN; data type torch.float32.
  - b_scale (torch.Tensor) – Row-major scale tensor for b, shape (batch_size, n // block_size, k // block_size) if scale_major_mode is K, or shape (batch_size, k // block_size, n // block_size) if scale_major_mode is MN; data type torch.float32.
  - m_indptr (torch.Tensor) – The indptr of the segment lengths, shape (batch_size + 1,), data type torch.int32. Each element in m_indptr must be a multiple of 4.
  - scale_granularity_mnk (Tuple[int, int, int]) – The granularity of the scale tensors, as (m_granularity, n_granularity, k_granularity).
  - scale_major_mode (Literal["MN", "K"]) – The layout mode of the scale tensors: MN for MN-major scales with shape (k // block_size, *), K for K-major scales with shape (*, k // block_size).
  - mma_sm (int) – How many SMs to use for the MMA operation, must be 1 or 2. 2 is faster when the number of rows (M) per group is large (>= 256).
  - out (Optional[torch.Tensor]) – The output tensor, shape (cum_m, n). If not specified, an output tensor is allocated internally.
  - out_dtype (Optional[torch.dtype]) – The data type of the output tensor, must be torch.bfloat16 or torch.float16.
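The interaction between scale_major_mode and the scale-tensor shapes can be made concrete with a small sketch. The helper below is not part of flashinfer; it is a hypothetical, pure-Python illustration of the shape rules stated above, assuming a block size of 128 (matching the default 128-element granularity in scale_granularity_mnk):

```python
# Hypothetical helper (not part of flashinfer): derive the expected
# a_scale / b_scale shapes from the parameter descriptions above.
def scale_shapes(batch_size, cum_m, n, k, scale_major_mode, block_size=128):
    kb = k // block_size
    if scale_major_mode == "K":
        # K-major: the k // block_size dimension comes last.
        a_scale_shape = (cum_m, kb)
        b_scale_shape = (batch_size, n // block_size, kb)
    elif scale_major_mode == "MN":
        # MN-major: the k // block_size dimension comes first (per group).
        a_scale_shape = (kb, cum_m)
        b_scale_shape = (batch_size, kb, n // block_size)
    else:
        raise ValueError("scale_major_mode must be 'MN' or 'K'")
    return a_scale_shape, b_scale_shape

# Example: 4 groups, cum_m=512 total rows, n=4096, k=2048, MN-major scales.
print(scale_shapes(4, 512, 4096, 2048, "MN"))  # ((16, 512), (4, 16, 32))
```

In real usage these shapes would be used to allocate torch.float32 tensors for a_scale and b_scale before calling the kernel.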
- Returns:
  out – The output tensor, shape (cum_m, n).
- Return type:
  torch.Tensor
Notes
Each value in m_indptr should be padded to a multiple of 4 before calling this function, to accommodate the kernel's requirement.
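The padding requirement above can be satisfied when building m_indptr from raw per-group row counts. The helper below is a hypothetical illustration, not a flashinfer API: it rounds each group's row count up to a multiple of 4 and accumulates, so every element of the resulting indptr is a multiple of 4:

```python
# Hypothetical helper (not part of flashinfer): build an indptr of shape
# (batch_size + 1,) in which every element is a multiple of 4, as the
# kernel requires, by padding each group's row count before accumulating.
def build_m_indptr(segment_lengths):
    indptr = [0]
    for m in segment_lengths:
        padded = (m + 3) // 4 * 4  # round up to the next multiple of 4
        indptr.append(indptr[-1] + padded)
    return indptr

# Example: raw group sizes 5, 8, 3 are padded to 8, 8, 4.
print(build_m_indptr([5, 8, 3]))  # [0, 8, 16, 20]
```

In practice the result would be converted to a device tensor, e.g. torch.tensor(indptr, dtype=torch.int32, device="cuda"), and the rows of a would be laid out (with padding rows) to match.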