flashinfer.gemm.gemm_fp8_nt_groupwise

flashinfer.gemm.gemm_fp8_nt_groupwise(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, scale_major_mode: Literal['MN', 'K'] = 'MN', mma_sm: int = 1, scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), out: torch.Tensor | None = None, out_dtype: torch.dtype | None = None) → torch.Tensor

Performs matrix multiplication with FP8 data types using groupwise scaling.

This function implements a GEMM operation that allows for fine-grained control over scale granularity across different dimensions. Currently only supported on NVIDIA Blackwell architecture.
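As a rough illustration of what groupwise scaling means numerically, the sketch below dequantizes and multiplies small matrices in pure Python. It assumes the K-major scale layout, treats the inputs as values already decoded to floats, and uses a hypothetical helper name (ref_gemm_groupwise). It is a reference for the arithmetic only, not the kernel's implementation.

```python
def ref_gemm_groupwise(a, b, a_scale, b_scale, granularity_mnk=(1, 128, 128)):
    """Reference for out = a @ b.T with groupwise scales (illustrative only).

    a: m x k nested lists, b: n x k nested lists (values decoded to float).
    a_scale: (m // gm) x (k // gk), b_scale: (n // gn) x (k // gk), K-major.
    """
    gm, gn, gk = granularity_mnk
    m, n, k = len(a), len(b), len(a[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for kb in range(k // gk):
                # All elements of one K block share a single pair of scales.
                partial = sum(a[i][p] * b[j][p]
                              for p in range(kb * gk, (kb + 1) * gk))
                acc += a_scale[i // gm][kb] * b_scale[j // gn][kb] * partial
            out[i][j] = acc
    return out


# Tiny example with granularity (1, 2, 2) so the blocks are easy to follow.
a = [[1, 2, 3, 4], [0, 1, 0, 1]]      # shape (m=2, k=4)
b = [[1, 0, 1, 0], [2, 2, 2, 2]]      # shape (n=2, k=4)
a_scale = [[0.5, 2.0], [1.0, 1.0]]    # shape (m, k // 2)
b_scale = [[2.0, 1.0]]                # shape (n // 2, k // 2)
result = ref_gemm_groupwise(a, b, a_scale, b_scale, granularity_mnk=(1, 2, 2))
```

Each partial dot product over a K block is multiplied by the scale of the row group of a and the scale of the row group of b that it belongs to, which is why finer granularity gives finer control over dynamic range.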

Parameters:
  • a (torch.Tensor) – Row-major input tensor of shape (m, k), with dtype fp8 e4m3 or fp8 e5m2.

  • b (torch.Tensor) – Column-major input tensor of shape (n, k), with dtype fp8 e4m3 or fp8 e5m2.

  • a_scale (torch.Tensor) – Column-major scale tensor for a, of shape (m, k // block_size) if scale_major_mode is K, or (k // block_size, m) if scale_major_mode is MN.

  • b_scale (torch.Tensor) – Row-major scale tensor for b, of shape (n // block_size, k // block_size) if scale_major_mode is K, or (k // block_size, n // block_size) if scale_major_mode is MN.

  • scale_granularity_mnk (Tuple[int, int, int]) – The granularity of the scale tensor, (m_granularity, n_granularity, k_granularity).

  • scale_major_mode (Literal["MN", "K"]) – The layout mode of the scale tensors: MN for MN-major scales with shape (k // block_size, *), or K for K-major scales with shape (*, k // block_size).

  • mma_sm (int) – Number of SMs to use for the MMA operation; must be 1 or 2. A value of 2 is faster when the number of rows (M) per group is large (>= 256).

  • out (Optional[torch.Tensor]) – Output tensor, shape (m, n). If not specified, an output tensor is allocated.

  • out_dtype (Optional[torch.dtype]) – Dtype of the output tensor that is allocated when out is not specified. Defaults to torch.bfloat16.
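The scale shapes above can be summarized in a small helper. This is a sketch under stated assumptions: expected_scale_shapes is a hypothetical name, and it assumes that block_size in the shape formulas corresponds to the matching entry of scale_granularity_mnk (which reproduces the documented shapes for the default (1, 128, 128)).

```python
def expected_scale_shapes(m, n, k,
                          scale_granularity_mnk=(1, 128, 128),
                          scale_major_mode="MN"):
    """Return (a_scale_shape, b_scale_shape) following the documented layouts.

    Assumption: block_size in the docs maps to the per-dimension granularity.
    """
    gm, gn, gk = scale_granularity_mnk
    if scale_major_mode == "K":
        # K-major: the K-block index is the last dimension.
        return (m // gm, k // gk), (n // gn, k // gk)
    if scale_major_mode == "MN":
        # MN-major: the K-block index is the first dimension.
        return (k // gk, m // gm), (k // gk, n // gn)
    raise ValueError("scale_major_mode must be 'MN' or 'K'")


shapes_k = expected_scale_shapes(256, 512, 1024, scale_major_mode="K")
shapes_mn = expected_scale_shapes(256, 512, 1024, scale_major_mode="MN")
```

A helper like this can be used to validate a_scale and b_scale before the call, so that a layout mismatch is reported as a shape error rather than producing wrong results.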

Returns:

out – Output tensor, shape (m, n).

Return type:

torch.Tensor

Notes

m should be padded to a multiple of 4 before calling this function, to meet the kernel’s requirement.
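The padded row count can be computed with a ceiling-division idiom. The helper name padded_rows below is hypothetical, and the sketch only computes sizes; how a and a_scale are actually padded (for example, copying into zero-initialized buffers and slicing the output back to the first m rows) is left to the caller.

```python
def padded_rows(m, multiple=4):
    """Round m up to the next multiple of `multiple` (4 per the note above)."""
    return (m + multiple - 1) // multiple * multiple


m = 5
m_padded = padded_rows(m)   # rows to allocate for a (and for a_scale's m dimension)
extra = m_padded - m        # rows of padding to add and later discard
```

Values of m that are already a multiple of 4 are returned unchanged, so the helper is safe to apply unconditionally.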