flashinfer.gemm.group_deepgemm_fp8_nt_groupwise

flashinfer.gemm.group_deepgemm_fp8_nt_groupwise(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, m_indices: torch.Tensor, scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), out: torch.Tensor | None = None, out_dtype: torch.dtype | None = None)

Perform grouped matrix multiplication with FP8 data types using the DeepGEMM backend.

This function performs a grouped GEMM in which each group in tensor b is multiplied with the rows of tensor a assigned to it. The assignment is given by the m_indices tensor, which specifies the group each row of a belongs to. This is particularly useful for mixture-of-experts (MoE) layers, where different tokens are routed to different experts.
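In an MoE layer, the router usually produces one expert id per token, in arbitrary token order. The following is a minimal sketch of turning that output into rows of a plus an m_indices tensor. It assumes that rows should be grouped contiguously by expert, as in the conceptual loop and the example on this page. The approach stable-sorts tokens by expert id and keeps the permutation so that results can be scattered back. `permute_by_expert` and `unpermute` are hypothetical helpers, not part of flashinfer. Any padding or alignment the backend may require is not handled here.

```python
import torch


def permute_by_expert(x: torch.Tensor, expert_ids: torch.Tensor):
    """Hypothetical helper: reorder rows of x so rows for the same expert are contiguous.

    Returns the permuted rows, an int32 m_indices tensor, and the permutation
    needed to restore the original token order.
    """
    # A stable sort keeps the original token order within each expert.
    sorted_ids, order = torch.sort(expert_ids, stable=True)
    return x[order], sorted_ids.to(torch.int32), order


def unpermute(out: torch.Tensor, order: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: scatter grouped-GEMM output rows back to token order."""
    restored = torch.empty_like(out)
    restored[order] = out
    return restored


# Example: 6 tokens routed to 3 experts; each token's single feature is its index.
x = torch.arange(6, dtype=torch.float32).unsqueeze(1)
expert_ids = torch.tensor([2, 0, 1, 0, 2, 1])
x_perm, m_indices, order = permute_by_expert(x, expert_ids)
print(m_indices.tolist())          # [0, 0, 1, 1, 2, 2]
print(x_perm.squeeze(1).tolist())  # [1.0, 3.0, 2.0, 5.0, 0.0, 4.0]
```

The stable sort is a design choice. Tokens that share an expert keep their relative order, which makes the permutation easy to reason about when debugging.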

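Ignoring the FP8 scaling factors, and assuming the rows of a are arranged in num_groups contiguous blocks of m_per_group rows each, the operation can be conceptualized as:

>>> for i in range(num_groups):
...     row_slice = slice(i * m_per_group, (i + 1) * m_per_group)
...     output[row_slice] = a[row_slice] @ b[i].T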
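The loop above assumes equal-sized, contiguous groups. The more general rule expressed by m_indices is that output row r equals a[r] @ b[m_indices[r]].T. Below is a float32 reference sketch of that rule, intended for checking results on small inputs rather than for performance. It ignores the FP8 scales. `grouped_gemm_reference` is a hypothetical name.

```python
import torch


def grouped_gemm_reference(a: torch.Tensor, b: torch.Tensor, m_indices: torch.Tensor) -> torch.Tensor:
    """Unscaled reference: out[r] = a[r] @ b[m_indices[r]].T for every row r.

    a: (m, k), b: (num_groups, n, k), m_indices: (m,) integer group ids.
    """
    m, n = a.shape[0], b.shape[1]
    out = torch.zeros(m, n, dtype=torch.float32, device=a.device)
    for g in torch.unique(m_indices).tolist():
        rows = m_indices == g  # boolean mask of rows assigned to group g
        out[rows] = a[rows].float() @ b[g].float().T
    return out


a = torch.randn(6, 8)
b = torch.randn(3, 4, 8)
m_indices = torch.tensor([0, 0, 2, 2, 1, 1], dtype=torch.int32)
out = grouped_gemm_reference(a, b, m_indices)
print(out.shape)  # torch.Size([6, 4])
```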

Currently supported only on the NVIDIA Blackwell (SM100) architecture.

Parameters:
  • a (torch.Tensor) – Input tensor A of shape (m, k) with FP8 data type (torch.float8_e4m3fn). This tensor contains all rows that will be multiplied with different groups in b.

  • b (torch.Tensor) – Input tensor B of shape (batch_size, n, k) with FP8 data type (torch.float8_e4m3fn), where batch_size is the number of groups. Each slice b[i] represents a different group/expert that will be multiplied with the rows of a assigned to it.

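  • a_scale (torch.Tensor) – Scaling factors for tensor a of shape (m, k // block_size) with torch.float32 dtype, where block_size is the k granularity from scale_granularity_mnk (128 by default). Each entry scales one row of a over one block of block_size elements along k. These are typically generated by per-token quantization of the original float32 tensor.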

  • b_scale (torch.Tensor) – Scaling factors for tensor b of shape (batch_size, n // block_size, k // block_size) with torch.float32 dtype, where block_size is set by scale_granularity_mnk (128 by default). These are typically generated by per-block quantization of the original float32 tensor, separately for each group.

  • m_indices (torch.Tensor) – Group assignment tensor of shape (m,) with torch.int32 dtype. Each element specifies which group (index into b) the corresponding row in a belongs to. For example, if m_indices[i] = j, then row i in a will be multiplied with group j in b.

  • scale_granularity_mnk (Tuple[int, int, int], optional) – The granularity of the scaling factors as (m_granularity, n_granularity, k_granularity). The default, (1, 128, 128), means per-token scaling for a (one scale per row and per 128 elements of k) and 128x128 block scaling for b.

  • out (Optional[torch.Tensor], optional) – Pre-allocated output tensor of shape (m, n). If not provided, a new tensor will be created.

  • out_dtype (Optional[torch.dtype], optional) – Data type of the output tensor. If out is provided, this parameter is ignored. Default is torch.bfloat16.
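With the default scale_granularity_mnk = (1, 128, 128), each entry of a_scale covers one row of a and 128 consecutive elements along k. Each entry of b_scale covers a 128x128 tile of one group of b. The sketch below expands those scales back to element-wise form, which makes the shapes concrete. It assumes that a dequantized value is the FP8 value multiplied by its scale, and that n and k are multiples of the granularity. The helper names are hypothetical. Applying an unscaled grouped matmul to the dequantized tensors gives a float32 reference for the full operation.

```python
import torch


def expand_a_scale(a_scale: torch.Tensor, block_k: int = 128) -> torch.Tensor:
    """(m, k // block_k) -> (m, k): repeat each scale over its block of k."""
    return a_scale.repeat_interleave(block_k, dim=1)


def expand_b_scale(b_scale: torch.Tensor, block_n: int = 128, block_k: int = 128) -> torch.Tensor:
    """(groups, n // block_n, k // block_k) -> (groups, n, k)."""
    return b_scale.repeat_interleave(block_n, dim=1).repeat_interleave(block_k, dim=2)


def dequantize(a_fp8, a_scale, b_fp8, b_scale, granularity=(1, 128, 128)):
    """Return float32 copies of a and b with the block scales applied.

    Assumes per-row scaling for a (m granularity 1). Inputs may be
    float8_e4m3fn or any other float dtype; .float() upcasts either way.
    """
    _, block_n, block_k = granularity
    a = a_fp8.float() * expand_a_scale(a_scale, block_k)
    b = b_fp8.float() * expand_b_scale(b_scale, block_n, block_k)
    return a, b


# Shapes for m=4, n=256, k=384, 2 groups (float32 ones stand in for FP8 data).
a_q = torch.ones(4, 384)
b_q = torch.ones(2, 256, 384)
a_s = torch.rand(4, 384 // 128)
b_s = torch.rand(2, 256 // 128, 384 // 128)
a_deq, b_deq = dequantize(a_q, a_s, b_q, b_s)
print(a_deq.shape, b_deq.shape)  # torch.Size([4, 384]) torch.Size([2, 256, 384])
```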

Returns:

Output tensor of shape (m, n) containing the results of the grouped matrix multiplication.

Return type:

torch.Tensor

Examples

>>> import torch
>>> from flashinfer.gemm import group_deepgemm_fp8_nt_groupwise
>>> from flashinfer.utils import per_token_cast_to_fp8, per_block_cast_to_fp8
>>>
>>> # Setup: 2 groups, 128 tokens per group, 4096 hidden size, 2048 expert size
>>> m_per_group, n, k = 128, 2048, 4096
>>> num_groups = 2
>>> m = m_per_group * num_groups
>>>
>>> # Create float32 inputs
>>> a_f32 = torch.randn(m, k, device="cuda", dtype=torch.float32)
>>> b_f32 = torch.randn(num_groups, n, k, device="cuda", dtype=torch.float32)
>>>
>>> # Quantize to FP8: per-token scales for a, 128x128 block scales for each group of b
>>> a_fp8, a_scale = per_token_cast_to_fp8(a_f32)
>>> b_fp8 = torch.empty_like(b_f32, dtype=torch.float8_e4m3fn)
>>> b_scale = torch.empty((num_groups, n // 128, k // 128), device="cuda", dtype=torch.float32)
>>> for i in range(num_groups):
...     b_fp8[i], b_scale[i] = per_block_cast_to_fp8(b_f32[i])
>>>
>>> # Assign each contiguous block of m_per_group rows to one group
>>> m_indices = torch.empty(m, device="cuda", dtype=torch.int32)
>>> for i in range(num_groups):
...     row_slice = slice(i * m_per_group, (i + 1) * m_per_group)
...     m_indices[row_slice] = i
>>>
>>> # Perform grouped GEMM
>>> result = group_deepgemm_fp8_nt_groupwise(
...     a_fp8, b_fp8, a_scale, b_scale, m_indices, out_dtype=torch.bfloat16
... )
>>> print(result.shape)
torch.Size([256, 2048])

Notes

  • This function requires NVIDIA Blackwell (SM100) architecture

  • The scaling factors should be generated using appropriate quantization functions, such as per_token_cast_to_fp8 for a and per_block_cast_to_fp8 for b

  • The function internally uses the DeepGEMM backend for optimized FP8 computation

  • All input tensors must be on the same CUDA device

  • The block size for scaling is determined by the scale_granularity_mnk parameter