flashinfer.gemm.gemm_fp8_nt_groupwise

flashinfer.gemm.gemm_fp8_nt_groupwise(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), out: torch.Tensor | None = None, out_dtype: torch.dtype | None = None) → torch.Tensor

Performs matrix multiplication with FP8 data types using groupwise scaling.

This function implements a GEMM operation with fine-grained control over scale granularity along the m, n, and k dimensions. It is currently only supported on the NVIDIA Blackwell architecture.

Parameters:
  • a (torch.Tensor) – Row-major input tensor of shape (m, k), fp8 e4m3 or fp8 e5m2.

  • b (torch.Tensor) – Column-major input tensor of shape (n, k), fp8 e4m3 or fp8 e5m2.

  • a_scale (torch.Tensor) – Column-major scale tensor for a, of shape (ceil_div(k, k_granularity), ceil_div(m, m_granularity)); see the shape sketch after this list.

  • b_scale (torch.Tensor) – Row-major scale tensor for b, of shape (ceil_div(k, k_granularity), ceil_div(n, n_granularity)).

  • scale_granularity_mnk (Tuple[int, int, int]) – The granularity of the scale tensors, as (m_granularity, n_granularity, k_granularity). Defaults to (1, 128, 128).

  • out (Optional[torch.Tensor]) – Output tensor of shape (m, n). If not specified, we will allocate a new output tensor.

  • out_dtype (Optional[torch.dtype]) – Dtype of the output tensor allocated when out is not specified. Defaults to torch.bfloat16.
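
For example, with the default granularity (1, 128, 128), an input a of shape (m, k) = (256, 4096) takes an a_scale of shape (ceil_div(4096, 128), ceil_div(256, 1)) = (32, 256). Below is a minimal sketch of the shape arithmetic; the ceil_div helper is illustrative, not part of the API:

    def ceil_div(x: int, y: int) -> int:
        # Smallest integer greater than or equal to x / y.
        return (x + y - 1) // y

    m, n, k = 256, 512, 4096
    m_gran, n_gran, k_gran = 1, 128, 128  # default scale_granularity_mnk

    a_scale_shape = (ceil_div(k, k_gran), ceil_div(m, m_gran))  # (32, 256)
    b_scale_shape = (ceil_div(k, k_gran), ceil_div(n, n_gran))  # (32, 4)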

Returns:

out – Output tensor, shape (m, n).

Return type:

torch.Tensor

Notes

If m is not a multiple of 4, we will pad m to the next multiple of 4 to accommodate the kernel’s requirement.
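
Examples

A minimal usage sketch, assuming a Blackwell GPU with FP8 support, the default (1, 128, 128) granularity, and float32 scale tensors (the all-ones scales are placeholders; real code would compute them while quantizing the inputs to FP8):

    import torch
    from flashinfer.gemm import gemm_fp8_nt_groupwise

    m, n, k = 256, 512, 4096

    # FP8 e4m3 operands: a is row-major (m, k); b holds the transposed
    # operand with shape (n, k) (the "nt" layout).
    a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)

    # Placeholder scales. a_scale must be column-major with shape
    # (k // 128, m): transposing a contiguous (m, k // 128) tensor yields
    # that layout. b_scale is row-major with shape (k // 128, n // 128).
    a_scale = torch.ones(m, k // 128, device="cuda", dtype=torch.float32).t()
    b_scale = torch.ones(k // 128, n // 128, device="cuda", dtype=torch.float32)

    out = gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out_dtype=torch.bfloat16)
    print(out.shape)  # torch.Size([256, 512])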