flashinfer.gemm.group_gemm_fp8_nt_groupwise
- flashinfer.gemm.group_gemm_fp8_nt_groupwise(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, m_indptr: torch.Tensor, scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), mma_sm: int = 1, out: torch.Tensor | None = None, out_dtype: torch.dtype | None = None) → torch.Tensor
Perform group GEMM with FP8 data types using groupwise scaling. Currently supported only on the NVIDIA Blackwell architecture.
- Parameters:
  - a (torch.Tensor) – Row-major input tensor of shape (cum_m, k), data type torch.float8_e4m3fn or torch.float8_e5m2. cum_m is the cumulative sum of the segment lengths.
  - b (torch.Tensor) – Column-major input tensor of shape (batch_size, n, k), data type torch.float8_e4m3fn or torch.float8_e5m2.
  - a_scale (torch.Tensor) – Column-major scale tensor for a, shape (k // block_size, cum_m).
  - b_scale (torch.Tensor) – Row-major scale tensor for b, shape (batch_size, k // block_size, n // block_size).
  - m_indptr (torch.Tensor) – The indptr of the segment lengths, shape (batch_size + 1,). Each element of m_indptr must be a multiple of 4.
  - scale_granularity_mnk (Tuple[int, int, int]) – The granularity of the scale tensors, given as (m_granularity, n_granularity, k_granularity).
  - mma_sm (int) – Number of SMs to use for the MMA operation; must be 1 or 2. Using 2 is faster when the number of rows (M) per group is large (>= 256).
  - out (Optional[torch.Tensor]) – The output tensor, shape (cum_m, n). If not specified, an output tensor is allocated internally.
  - out_dtype (Optional[torch.dtype]) – The data type of the output tensor.
- Returns:
  out – The output tensor, shape (cum_m, n).
- Return type:
  torch.Tensor
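
A minimal usage sketch follows. It assumes a Blackwell GPU with a flashinfer build that includes FP8 group GEMM support; the float32 dtype of the scale tensors, the int32 dtype of m_indptr, and the concrete segment lengths, problem sizes, and out_dtype are illustrative assumptions, not part of the documented contract.

```python
import torch
import flashinfer

batch_size, n, k = 4, 4096, 4096
block_size = 128  # matches the default scale granularity (1, 128, 128)

# Per-group segment lengths; each must be a multiple of 4.
m_per_group = torch.tensor([8, 16, 32, 64], dtype=torch.int32, device="cuda")
m_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
m_indptr[1:] = torch.cumsum(m_per_group, dim=0)
cum_m = int(m_indptr[-1])  # cumulative sum of segment lengths (here 120)

# FP8 operands: a is row-major (cum_m, k); b has shape (batch_size, n, k)
# and is column-major in its last two dims (built via a transposed view).
a = torch.randn(cum_m, k, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(batch_size, k, n, device="cuda").to(torch.float8_e4m3fn).transpose(-1, -2)

# Groupwise scale tensors (float32 assumed): a_scale is column-major of shape
# (k // block_size, cum_m); b_scale is row-major of shape
# (batch_size, k // block_size, n // block_size).
a_scale = torch.rand(cum_m, k // block_size, device="cuda").t()
b_scale = torch.rand(batch_size, k // block_size, n // block_size, device="cuda")

out = flashinfer.gemm.group_gemm_fp8_nt_groupwise(
    a, b, a_scale, b_scale, m_indptr, out_dtype=torch.bfloat16
)
print(out.shape)  # torch.Size([cum_m, n])
```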