flashinfer.gemm.group_deepgemm_fp8_nt_groupwise¶
- flashinfer.gemm.group_deepgemm_fp8_nt_groupwise(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, m_indices: torch.Tensor, scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), out: torch.Tensor | None = None, out_dtype: torch.dtype | None = None)¶
Perform grouped matrix multiplication with FP8 data types using the DeepGEMM backend.
This function performs a grouped GEMM operation where each group in tensor b is multiplied with the corresponding rows in tensor a. The grouping is determined by the m_indices tensor, which specifies which group each row belongs to. This is particularly useful for scenarios like mixture of experts (MoE) where different tokens are routed to different experts.
The operation can be conceptualized as:
>>> for i in range(num_groups):
...     row_slice = slice(i * m_per_group, (i + 1) * m_per_group)
...     output[row_slice] = a[row_slice] @ b[i].T
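Row by row, each output row is the product of the corresponding row of a with the transposed weight of the group selected by m_indices. An illustrative float32 reference of this semantics (ignoring FP8 quantization and scaling; a_f32, b_f32, m, n, and m_indices are assumed to be set up as in the example below):
>>> out_ref = torch.empty(m, n, device="cuda", dtype=torch.float32)
>>> for i in range(m):
...     out_ref[i] = a_f32[i] @ b_f32[m_indices[i]].T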
Currently only supported on NVIDIA Blackwell (SM100) architecture.
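A guard along the following lines can fail fast on unsupported hardware. This is a sketch that assumes SM100 reports as CUDA compute capability major version 10 via torch.cuda.get_device_capability:
>>> major, minor = torch.cuda.get_device_capability()
>>> if major != 10:
...     raise RuntimeError("group_deepgemm_fp8_nt_groupwise requires an SM100 (Blackwell) GPU")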
- Parameters:
  - a (torch.Tensor) – Input tensor A of shape (m, k) with FP8 data type (torch.float8_e4m3fn). This tensor contains all rows that will be multiplied with different groups in b.
  - b (torch.Tensor) – Input tensor B of shape (batch_size, n, k) with FP8 data type (torch.float8_e4m3fn). Each slice b[i] represents a different group/expert that will be multiplied with the corresponding rows in a.
  - a_scale (torch.Tensor) – Scaling factors for tensor a of shape (m, k // block_size) with torch.float32 dtype. These are typically generated by per-token quantization of the original float32 tensor.
  - b_scale (torch.Tensor) – Scaling factors for tensor b of shape (batch_size, n // block_size, k // block_size) with torch.float32 dtype. These are typically generated by per-block quantization of the original float32 tensor for each group.
  - m_indices (torch.Tensor) – Group assignment tensor of shape (m,) with torch.int32 dtype. Each element specifies which group (index into b) the corresponding row in a belongs to. For example, if m_indices[i] = j, then row i in a will be multiplied with group j in b.
  - scale_granularity_mnk (Tuple[int, int, int], optional) – The granularity of the scaling factors as (m_granularity, n_granularity, k_granularity). Default is (1, 128, 128), which means per-token scaling for a and 128x128 block scaling for b.
  - out (Optional[torch.Tensor], optional) – Pre-allocated output tensor of shape (m, n). If not provided, a new tensor will be created.
  - out_dtype (Optional[torch.dtype], optional) – Data type of the output tensor. Ignored if out is provided. Default is torch.bfloat16.
- Returns:
  Output tensor of shape (m, n) containing the results of the grouped matrix multiplication.
- Return type:
  torch.Tensor
Examples
>>> import torch
>>> from flashinfer.gemm import group_deepgemm_fp8_nt_groupwise
>>> from flashinfer.utils import per_token_cast_to_fp8, per_block_cast_to_fp8
>>>
>>> # Setup: 2 groups, 128 tokens per group, 4096 hidden size, 2048 expert size
>>> m_per_group, n, k = 128, 2048, 4096
>>> group_size = 2
>>> m = m_per_group * group_size
>>>
>>> # Create float32 inputs
>>> a_f32 = torch.randn(m, k, device="cuda", dtype=torch.float32)
>>> b_f32 = torch.randn(group_size, n, k, device="cuda", dtype=torch.float32)
>>>
>>> # Quantize to FP8 with appropriate scaling
>>> a_fp8, a_scale = per_token_cast_to_fp8(a_f32)
>>> b_fp8 = torch.empty_like(b_f32, dtype=torch.float8_e4m3fn)
>>> b_scale = torch.empty((group_size, n // 128, k // 128), device="cuda", dtype=torch.float32)
>>> for i in range(group_size):
...     b_fp8[i], b_scale[i] = per_block_cast_to_fp8(b_f32[i])
>>>
>>> # Create group assignment
>>> m_indices = torch.empty(m, device="cuda", dtype=torch.int32)
>>> for i in range(group_size):
...     row_slice = slice(i * m_per_group, (i + 1) * m_per_group)
...     m_indices[row_slice] = i
>>>
>>> # Perform grouped GEMM
>>> result = group_deepgemm_fp8_nt_groupwise(
...     a_fp8, b_fp8, a_scale, b_scale, m_indices, out_dtype=torch.bfloat16
... )
>>> print(result.shape)  # torch.Size([256, 2048])
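In MoE workloads, tokens are usually distributed unevenly across experts. Provided the rows for each expert remain contiguous, m_indices can be built from per-expert row counts, e.g. with torch.repeat_interleave. The counts below are hypothetical; any backend alignment requirements on per-expert row counts are not addressed here:
>>> # Hypothetical uneven routing: expert 0 gets 128 rows, expert 1 gets 384 rows (m = 512)
>>> counts = torch.tensor([128, 384], device="cuda")
>>> m_indices_uneven = torch.repeat_interleave(
...     torch.arange(len(counts), device="cuda", dtype=torch.int32), counts
... )
>>> m_indices_uneven.shape
torch.Size([512])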
Notes
- This function requires the NVIDIA Blackwell (SM100) architecture.
- The scaling factors should be generated with appropriate quantization functions, such as per_token_cast_to_fp8 for a and per_block_cast_to_fp8 for b.
- The function internally uses the DeepGEMM backend for optimized FP8 computation.
- All input tensors must be on the same CUDA device.
- The block size for scaling is determined by the scale_granularity_mnk parameter.
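For reference, a 1x128 per-token scaling for a can be sketched as follows. This only illustrates the expected a_scale layout ((m, k // 128), torch.float32); it is not flashinfer's per_token_cast_to_fp8 implementation, assumes k is a multiple of 128, and uses 448 as the torch.float8_e4m3fn maximum:
>>> # Illustrative 1x128-block quantization sketch; a_f32, m, k as in the example above
>>> blocks = a_f32.view(m, k // 128, 128)
>>> amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
>>> a_scale_ref = (amax / 448.0).squeeze(-1)  # shape (m, k // 128), float32
>>> a_fp8_ref = (blocks * (448.0 / amax)).view(m, k).to(torch.float8_e4m3fn)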