flashinfer.gemm.group_gemm_nvfp4_nt_groupwise

flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(a: Tensor, b: Tensor, a_scale: Tensor, b_scale: Tensor, m_indptr: Tensor, alpha: Tensor | None = None, tile_m: int = 128, tile_n: int = 128, tile_k: int = 128, swap_ab: bool = True, out: Tensor | None = None, out_dtype: dtype | None = None) → Tensor

Perform a group GEMM with NVFP4 data types using groupwise scaling. Currently implemented only on NVIDIA Blackwell GeForce and DGX Spark architectures.
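To make the grouped "NT" semantics concrete, here is a minimal pure-Python reference on already-dequantized float values. It ignores NVFP4 packing and block scales entirely, does not enforce the multiple-of-4 rule on m_indptr, and assumes alpha acts as a per-group multiplier on the output. The function name group_gemm_nt_reference is hypothetical and is not part of FlashInfer.

```python
def group_gemm_nt_reference(a, b, m_indptr, alpha=None):
    """Reference grouped GEMM: out[rows of group g] = alpha[g] * a_g @ b[g]^T.

    a:        list of cum_m rows, each a list of k floats (row-major, "N").
    b:        list of batch_size matrices, each with n rows of k floats ("T").
    m_indptr: list of batch_size + 1 row offsets into a.
    alpha:    optional list of batch_size per-group multipliers (assumed semantics).
    """
    batch_size = len(m_indptr) - 1
    out = []
    for g in range(batch_size):
        scale = 1.0 if alpha is None else alpha[g]
        w = b[g]  # shape (n, k): each output column is a dot product with one row of w
        for row in a[m_indptr[g]:m_indptr[g + 1]]:
            out.append([scale * sum(x * y for x, y in zip(row, w_row)) for w_row in w])
    return out  # shape (cum_m, n)
```

For example, with two groups of 2 and 1 rows, group_gemm_nt_reference([[1, 2], [3, 4], [5, 6]], [[[1, 0], [0, 1]], [[1, 1], [2, 0]]], [0, 2, 3], alpha=[1.0, 0.5]) multiplies the first two rows by an identity weight and the last row by the second weight scaled by 0.5.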

Parameters:
  • a (torch.Tensor) – Row-major input tensor, shape (cum_m, k // 2), data type is torch.uint8 (packed NVFP4). cum_m is the cumulative sum of the segment lengths.

  • b (torch.Tensor) – Column-major input tensor, shape (batch_size, n, k // 2), data type is torch.uint8.

  • a_scale (torch.Tensor) – Column-major scale tensor for a, shape (cum_m_padded, k // 16), data type is torch.uint8.

  • b_scale (torch.Tensor) – Row-major scale tensor for b, shape (batch_size, n_padded, k // 16), data type is torch.uint8.

  • m_indptr (torch.Tensor) – The indptr of the segment lengths, shape (batch_size + 1,), data type is torch.int32. Each element in m_indptr must be a multiple of 4.

  • alpha (Optional[torch.Tensor]) – The alpha tensor, shape (batch_size, ), data type is torch.float32.

  • tile_m (int) – The tile size for the M dimension, must be 128.

  • tile_n (int) – The tile size for the N dimension, must be 32, 64, or 128.

  • tile_k (int) – The tile size for the K dimension, must be 128 or 256.

  • swap_ab (bool) – Whether to compute Output^T = Weight^T Activation^T instead of Output = Activation Weight. Defaults to True.

  • out (Optional[torch.Tensor]) – The output tensor, shape (cum_m, n). If not specified, a new output tensor is allocated.

  • out_dtype (Optional[torch.dtype]) – The data type of the output tensor, must be torch.bfloat16 or torch.float16.
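The constraints above can be checked before launching the kernel. The sketch below is a hypothetical pre-flight helper (check_group_gemm_args is not a FlashInfer function) that works on shape tuples only, so it runs without torch or a GPU. It assumes m_indptr starts at 0, is non-decreasing, and ends at cum_m, and that k must be a multiple of 16 so that the k // 16 scale columns cover every element. It does not compute cum_m_padded or n_padded, since their padding rule is not stated here.

```python
VALID_TILE_N = (32, 64, 128)
VALID_TILE_K = (128, 256)


def check_group_gemm_args(a_shape, b_shape, m_indptr,
                          tile_m=128, tile_n=128, tile_k=128):
    """Hypothetical validation of shapes and tile sizes; returns expected shapes."""
    cum_m, a_k_half = a_shape          # a: (cum_m, k // 2), packed uint8
    batch_size, n, b_k_half = b_shape  # b: (batch_size, n, k // 2), packed uint8
    if a_k_half != b_k_half:
        raise ValueError("a and b must have the same packed K dimension")
    k = 2 * a_k_half  # two 4-bit values are packed into each byte
    if k % 16 != 0:  # assumption: one scale per 16 elements must cover all of K
        raise ValueError("k must be a multiple of 16")
    if len(m_indptr) != batch_size + 1:
        raise ValueError("m_indptr must have batch_size + 1 entries")
    # Assumption: m_indptr starts at 0, is non-decreasing, and ends at cum_m.
    if m_indptr[0] != 0 or m_indptr[-1] != cum_m:
        raise ValueError("m_indptr must start at 0 and end at cum_m")
    if any(nxt < prev for prev, nxt in zip(m_indptr, m_indptr[1:])):
        raise ValueError("m_indptr must be non-decreasing")
    if any(v % 4 != 0 for v in m_indptr):
        raise ValueError("every element of m_indptr must be a multiple of 4")
    if tile_m != 128 or tile_n not in VALID_TILE_N or tile_k not in VALID_TILE_K:
        raise ValueError("unsupported tile configuration")
    return {
        "k": k,
        "scale_cols": k // 16,   # columns of a_scale and b_scale
        "alpha": (batch_size,),
        "out": (cum_m, n),
    }
```

For instance, a of shape (8, 64) and b of shape (2, 32, 64) with m_indptr [0, 4, 8] describe k = 128, so each scale tensor has 8 columns and the output has shape (8, 32).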

Returns:

out – The output tensor, shape (cum_m, n).

Return type:

torch.Tensor

Notes

Each value in m_indptr must be a multiple of 4, as required by the kernel, so pad each segment length up to a multiple of 4 before calling this function.
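A minimal sketch of that padding step, assuming you start from the unpadded per-group row counts. The hypothetical helper padded_m_indptr rounds each segment up to a multiple of 4 and also returns, for every original row, its position in the padded a. Because each output row of a GEMM depends only on the matching row of a, the extra rows can hold anything (zeros are a natural choice), and their output rows are simply dropped afterwards.

```python
def padded_m_indptr(seg_lens, align=4):
    """Round each segment length up to a multiple of `align`.

    Returns (m_indptr, row_map), where row_map[j] is the row of the padded
    input (and output) that holds original row j.
    """
    m_indptr = [0]
    row_map = []
    for length in seg_lens:
        start = m_indptr[-1]
        row_map.extend(range(start, start + length))  # real rows come first in each segment
        padded = (length + align - 1) // align * align  # ceil to a multiple of align
        m_indptr.append(start + padded)
    return m_indptr, row_map
```

With segment lengths [3, 5, 4], this yields m_indptr [0, 4, 12, 16]; copying the original rows of a into the positions given by row_map (zero-filling the rest) builds the padded input, and gathering the same positions from out recovers the unpadded result.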