flashinfer.gemm.group_gemm_fp8_nt_groupwise

flashinfer.gemm.group_gemm_fp8_nt_groupwise(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, m_indptr: torch.Tensor, scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), mma_sm: int = 1, out: torch.Tensor | None = None, out_dtype: torch.dtype | None = None) → torch.Tensor

Perform group GEMM with FP8 data types using groupwise scaling. Currently only supported on NVIDIA Blackwell architecture.

Parameters:
  • a (torch.Tensor) – Row-major input tensor of shape (cum_m, k), data type torch.float8_e4m3fn or torch.float8_e5m2. cum_m is the cumulative sum of the segment lengths over all groups.

  • b (torch.Tensor) – Column-major input tensor of shape (batch_size, n, k), data type torch.float8_e4m3fn or torch.float8_e5m2.

  • a_scale (torch.Tensor) – Column-major scale tensor for a, shape (k // block_size, cum_m).

  • b_scale (torch.Tensor) – Row-major scale tensor for b, shape (batch_size, k // block_size, n // block_size).

  • m_indptr (torch.Tensor) – The indptr of the segment lengths, shape (batch_size + 1,). Each element in m_indptr must be a multiple of 4.

  • scale_granularity_mnk (Tuple[int, int, int]) – The granularity of the scale tensor, (m_granularity, n_granularity, k_granularity).

  • mma_sm (int) – How many SMs to use for the MMA operation, must be 1 or 2. 2 is faster when the number of rows (M) per group is large (>= 256).

  • out (Optional[torch.Tensor]) – The output tensor, shape (cum_m, n). If not specified, a new output tensor is allocated internally.

  • out_dtype (Optional[torch.dtype]) – The data type of the output tensor.

Returns:

out – The output tensor, shape (cum_m, n).

Return type:

torch.Tensor
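
The sketch below illustrates one way the arguments might be assembled, assuming flashinfer is installed and a Blackwell GPU is available. The helper `build_m_indptr`, the dimension values, and the all-ones scale tensors are illustrative choices, not part of the API; the scale tensors would normally come from the quantization step that produced the FP8 data.

```python
def build_m_indptr(group_sizes):
    """Build the (batch_size + 1,) indptr list from per-group row counts.

    Each segment length must be a multiple of 4, per the m_indptr
    requirement documented above.
    """
    indptr = [0]
    for m in group_sizes:
        assert m % 4 == 0, "each segment length must be a multiple of 4"
        indptr.append(indptr[-1] + m)
    return indptr


def run_example():
    # Hypothetical end-to-end call; requires flashinfer and a Blackwell GPU.
    import torch
    import flashinfer

    group_sizes = [8, 16, 4]                 # rows per group (illustrative)
    m_indptr_list = build_m_indptr(group_sizes)
    cum_m = m_indptr_list[-1]                # total rows across all groups
    batch_size, n, k, block = len(group_sizes), 256, 512, 128

    # FP8 inputs: a is row-major (cum_m, k); b is column-major (batch_size, n, k).
    a = torch.randn(cum_m, k, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(batch_size, n, k, device="cuda").to(torch.float8_e4m3fn)

    # Groupwise scales with block size 128; all-ones here only for illustration.
    a_scale = torch.ones(k // block, cum_m, device="cuda")
    b_scale = torch.ones(batch_size, k // block, n // block, device="cuda")
    m_indptr = torch.tensor(m_indptr_list, device="cuda", dtype=torch.int32)

    out = flashinfer.gemm.group_gemm_fp8_nt_groupwise(
        a, b, a_scale, b_scale, m_indptr,
        scale_granularity_mnk=(1, 128, 128),
        mma_sm=1,
        out_dtype=torch.bfloat16,
    )
    return out                               # shape (cum_m, n)
```

With the example segment lengths [8, 16, 4], `build_m_indptr` yields [0, 8, 24, 28], so `a` and the output both have 28 rows.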