flashinfer.gemm.group_gemm_mxfp4_nt_groupwise¶
- flashinfer.gemm.group_gemm_mxfp4_nt_groupwise(a: Tensor, b: Tensor, a_scale: Tensor, b_scale: Tensor, m_indptr: Tensor, mma_sm: int = 1, tile_m: int = 128, tile_n: int = 128, tile_k: int = 128, swap_ab: bool = True, out: Tensor | None = None, out_dtype: dtype | None = None) → Tensor¶
 Perform group GEMM with MXFP4 data types using groupwise scaling. Currently only supported on NVIDIA Blackwell architecture.
- Parameters:
 a (torch.Tensor) – Row-major input tensor, shape (cum_m, k), data type is torch.float8_e4m3fn or torch.float8_e5m2. cum_m is the cumulative sum of the segment lengths.
 b (torch.Tensor) – Column-major input tensor, shape (batch_size, n, k // 2), data type is torch.uint8.
 a_scale (torch.Tensor) – Column-major scale tensor for a, shape (cum_m_padded, k // 32), data type is torch.uint8.
 b_scale (torch.Tensor) – Row-major scale tensor for b, shape (batch_size, n_padded, k // 32), data type is torch.uint8.
 m_indptr (torch.Tensor) – The indptr of the segment lengths, shape (batch_size + 1,), data type is torch.int32. Each element in m_indptr must be a multiple of 4.
 mma_sm (int) – How many SMs to use for the MMA operation, must be 1 or 2. 2 is faster when the number of rows (M) per group is large (>= 256).
tile_m (int) – The tile size for the M dimension, must be 128.
tile_n (int) – The tile size for the N dimension, must be 64, 128, 192, or 256.
tile_k (int) – The tile size for the K dimension, must be 128 or 256.
swap_ab (bool) – Whether to swap the A and B tensors.
 out (Optional[torch.Tensor]) – The output tensor, shape (cum_m, n). If not specified, an output tensor will be created.
 out_dtype (Optional[torch.dtype]) – The data type of the output tensor, must be torch.bfloat16 or torch.float16.
- Returns:
 out – The output tensor, shape (cum_m, n).
- Return type:
 torch.Tensor
Notes
Each value in m_indptr should be padded to a multiple of 4 before calling this function, to accommodate the kernel’s requirement.
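Example
The following is a minimal usage sketch. The group sizes, the choice of already-aligned cum_m and n (so that cum_m_padded == cum_m and n_padded == n), and the randomly filled FP8/MXFP4 payloads and constant scale bytes are illustrative assumptions; in practice a, b, a_scale, and b_scale come from an actual quantization step.

```python
import torch
from flashinfer.gemm import group_gemm_mxfp4_nt_groupwise

batch_size = 2     # number of groups
k, n = 256, 256    # K and N are shared across all groups
m_per_group = 128  # rows per group; keeps every m_indptr entry a multiple of 4

# Cumulative row offsets per group: [0, 128, 256].
m_indptr = torch.arange(0, (batch_size + 1) * m_per_group, m_per_group,
                        dtype=torch.int32, device="cuda")
cum_m = batch_size * m_per_group

# FP8 activations, row-major, shape (cum_m, k).
a = torch.randn(cum_m, k, device="cuda").to(torch.float8_e4m3fn)

# MXFP4 weights packed two values per byte, shape (batch_size, n, k // 2).
b = torch.randint(0, 256, (batch_size, n, k // 2),
                  dtype=torch.uint8, device="cuda")

# Placeholder scale bytes (one per 32 elements along K); a real quantizer
# would produce these. a_scale is built as a transposed view to obtain a
# column-major layout; with these sizes no extra padding rows are needed.
a_scale = torch.full((k // 32, cum_m), 127, dtype=torch.uint8, device="cuda").t()
b_scale = torch.full((batch_size, n, k // 32), 127, dtype=torch.uint8, device="cuda")

out = group_gemm_mxfp4_nt_groupwise(a, b, a_scale, b_scale, m_indptr,
                                    out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([256, 256]), i.e. (cum_m, n)
```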