flashinfer.gemm.group_gemm_mxfp4_nt_groupwise

flashinfer.gemm.group_gemm_mxfp4_nt_groupwise(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, m_indptr: torch.Tensor, mma_sm: int = 1, tile_m: int = 128, tile_n: int = 128, tile_k: int = 128, swap_ab: bool = True, out: torch.Tensor | None = None, out_dtype: torch.dtype | None = None) torch.Tensor

Perform group GEMM with MXFP4 data types using groupwise scaling. Currently only supported on the NVIDIA Blackwell architecture.

Parameters:
  • a (torch.Tensor) – Row-major input tensor, shape (cum_m, k), data type is torch.float8_e4m3fn or torch.float8_e5m2. cum_m is the total of the segment lengths (the last element of m_indptr).

  • b (torch.Tensor) – Column-major input tensor, shape (batch_size, n, k // 2), data type is torch.uint8, with two MXFP4 values packed per byte.

  • a_scale (torch.Tensor) – Column-major scale tensor for a, shape (cum_m_padded, k // 32), data type is torch.uint8.

  • b_scale (torch.Tensor) – Row-major scale tensor for b, shape (batch_size, n_padded, k // 32), data type is torch.uint8.

  • m_indptr (torch.Tensor) – The indptr of the segment lengths, shape (batch_size + 1,), data type is torch.int32. Each element in m_indptr must be a multiple of 4.

  • mma_sm (int) – How many SMs to use for the MMA operation, must be 1 or 2. Using 2 is faster when the number of rows (M) per group is large (>= 256).

  • tile_m (int) – The tile size for the M dimension, must be 128.

  • tile_n (int) – The tile size for the N dimension, must be 64, 128, 192, or 256.

  • tile_k (int) – The tile size for the K dimension, must be 128 or 256.

  • swap_ab (bool) – Whether to swap the A and B tensors.

  • out (Optional[torch.Tensor]) – The output tensor, shape (cum_m, n). If not specified, a new output tensor is allocated.

  • out_dtype (Optional[torch.dtype]) – The data type of the output tensor, must be torch.bfloat16 or torch.float16.

Returns:

out – The output tensor, shape (cum_m, n).

Return type:

torch.Tensor

Notes

Each value in m_indptr should be padded to a multiple of 4 before calling this function, to satisfy the kernel's alignment requirement.
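
Example

The sketch below is a minimal, hypothetical illustration of the expected shapes and dtypes for a small problem. The scale tensors are filled with random bytes, so the output is not numerically meaningful; the example also assumes that cum_m and n need no extra padding (i.e. cum_m_padded == cum_m and n_padded == n) and that a Blackwell GPU is available.

import torch
from flashinfer.gemm import group_gemm_mxfp4_nt_groupwise

device = "cuda"
batch_size, n, k = 4, 256, 512

# Segment boundaries for 4 groups of 8, 16, 4, and 12 rows.
# Every entry is a multiple of 4, as required.
m_indptr = torch.tensor([0, 8, 24, 28, 40], dtype=torch.int32, device=device)
cum_m = int(m_indptr[-1])

# FP8 activations, row-major, shape (cum_m, k).
a = torch.randn(cum_m, k, device=device).to(torch.float8_e4m3fn)

# Packed MXFP4 weights: two 4-bit values per byte, shape (batch_size, n, k // 2).
b = torch.randint(0, 256, (batch_size, n, k // 2), dtype=torch.uint8, device=device)

# Groupwise scales (uint8), one scale per 32 elements along K.
# a_scale must be column-major, so allocate it transposed and view it back.
# Assumption: no extra padding is needed here (cum_m_padded == cum_m,
# n_padded == n); consult the library for the exact padding rules.
a_scale = torch.randint(0, 256, (k // 32, cum_m), dtype=torch.uint8, device=device).t()
b_scale = torch.randint(0, 256, (batch_size, n, k // 32), dtype=torch.uint8, device=device)

out = group_gemm_mxfp4_nt_groupwise(
    a, b, a_scale, b_scale, m_indptr,
    mma_sm=1, tile_m=128, tile_n=128, tile_k=128,
    out_dtype=torch.bfloat16,
)
print(out.shape)  # torch.Size([40, 256])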