flashinfer.gemm.batch_deepgemm_fp8_nt_groupwise¶
- flashinfer.gemm.batch_deepgemm_fp8_nt_groupwise(a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, masked_m: torch.Tensor, expected_m: int, scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), out: torch.Tensor | None = None, out_dtype: torch.dtype | None = None)¶
Perform batch matrix multiplication with FP8 data types using the DeepGEMM backend.
This function performs a batched GEMM in which each group (expert) in tensor b is multiplied with the corresponding group of rows in tensor a. The masked_m tensor specifies, for each group, how many leading rows of a are valid; rows beyond that count are skipped. This is particularly useful for scenarios like mixture of experts (MoE), where different numbers of tokens are routed to different experts.
The operation can be conceptualized as:
>>> for i in range(batch_size):
>>>     output[i][:masked_m[i]] = a[i][:masked_m[i]] @ b[i].T
Currently supported only on the NVIDIA Blackwell (SM100) architecture.
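For illustration, here is a minimal float32 sketch of these masked, groupwise semantics (illustration only; the real kernel consumes FP8 inputs and groupwise scales, and the helper name below is hypothetical):

import torch

def batch_masked_gemm_reference(a: torch.Tensor, b: torch.Tensor, masked_m: torch.Tensor) -> torch.Tensor:
    # a: (batch_size, m, k), b: (batch_size, n, k), masked_m: (batch_size,)
    batch_size, m, _ = a.shape
    n = b.shape[1]
    out = torch.zeros((batch_size, m, n), dtype=a.dtype, device=a.device)
    for i in range(batch_size):
        valid = int(masked_m[i])
        # Only the first masked_m[i] rows of group i are computed; the rest stay zero here.
        out[i, :valid] = a[i, :valid] @ b[i].T
    return out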
- Parameters:
  - a (torch.Tensor) – Input tensor A of shape (batch_size, m, k) with FP8 data type (torch.float8_e4m3fn). Each slice a[i] represents a group of rows that will be multiplied with the corresponding group/expert in b.
  - b (torch.Tensor) – Input tensor B of shape (batch_size, n, k) with FP8 data type (torch.float8_e4m3fn). Each slice b[i] represents a different group/expert that will be multiplied with the corresponding rows in a.
  - a_scale (torch.Tensor) – Scaling factors for tensor a of shape (batch_size, m, k // block_size) with torch.float32 dtype. These are typically produced by per-token quantization of the original float32 tensor.
  - b_scale (torch.Tensor) – Scaling factors for tensor b of shape (batch_size, n // block_size, k // block_size) with torch.float32 dtype. These are typically produced by per-block quantization of the original float32 tensor for each group.
  - masked_m (torch.Tensor) – Masking tensor of shape (batch_size,) with torch.int32 dtype. Each element specifies the number of effective rows in its group: if masked_m[i] = j, the first j rows of a[i] are multiplied with group i in b.
  - expected_m (int) – A hint (a CPU scalar) for the expected number of valid rows per group; setting it accurately may lead to better performance.
  - scale_granularity_mnk (Tuple[int, int, int], optional) – The granularity of the scaling factors as (m_granularity, n_granularity, k_granularity). The default (1, 128, 128) means per-token scaling for a and 128x128 block scaling for b.
  - out (Optional[torch.Tensor], optional) – Pre-allocated output tensor of shape (batch_size, m, n). If not provided, a new tensor is created.
  - out_dtype (Optional[torch.dtype], optional) – Data type of the output tensor; ignored if out is provided. Default is torch.bfloat16.
- Returns:
  Output tensor of shape (batch_size, m, n) containing the results of the batch matrix multiplication.
- Return type:
  torch.Tensor
Examples
>>> import torch
>>> from flashinfer.gemm import batch_deepgemm_fp8_nt_groupwise
>>> from flashinfer.utils import per_token_cast_to_fp8, per_block_cast_to_fp8
>>>
>>> # Setup: 2 groups, 128 tokens per group, 4096 hidden size, 2048 expert size
>>> m, n, k = 128, 2048, 4096
>>> group_size = 2
>>>
>>> # Create float32 inputs
>>> a = torch.rand((group_size, m, k), device="cuda", dtype=torch.float32)
>>> b = torch.rand((group_size, n, k), device="cuda", dtype=torch.float32)
>>> masked_m = torch.randint(0, m, (group_size,), device="cuda", dtype=torch.int32)
>>> a_fp8 = torch.empty_like(a, device="cuda", dtype=torch.float8_e4m3fn)
>>> a_scale = torch.empty((group_size, m, k // 128), device="cuda", dtype=torch.float32)
>>> b_fp8 = torch.empty_like(b, device="cuda", dtype=torch.float8_e4m3fn)
>>> b_scale = torch.empty(
...     (group_size, n // 128, k // 128), device="cuda", dtype=torch.float32
... )
>>> for i in range(group_size):
...     a_fp8[i], a_scale[i] = per_token_cast_to_fp8(a[i])
...     b_fp8[i], b_scale[i] = per_block_cast_to_fp8(b[i])
>>>
>>> expected_m = min(int(masked_m.float().mean()) + 1, m)
>>>
>>> # Perform batch GEMM
>>> result = batch_deepgemm_fp8_nt_groupwise(
...     a_fp8, b_fp8, a_scale, b_scale, masked_m, expected_m, out_dtype=torch.bfloat16
... )
>>> print(result.shape)  # torch.Size([2, 128, 2048])
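Continuing the example above, a pre-allocated out buffer (documented in the parameters) can be reused across calls, and only the first masked_m[i] rows of each group should be read back; the buffer name out_buf is illustrative:

>>> out_buf = torch.empty((group_size, m, n), device="cuda", dtype=torch.bfloat16)
>>> batch_deepgemm_fp8_nt_groupwise(
...     a_fp8, b_fp8, a_scale, b_scale, masked_m, expected_m, out=out_buf
... )
>>> for i in range(group_size):
...     valid = int(masked_m[i].item())
...     group_out = out_buf[i, :valid]  # rows beyond masked_m[i] should be treated as undefined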
Notes
- This function requires the NVIDIA Blackwell (SM100) architecture.
- The scaling factors should be generated using appropriate quantization functions such as per_token_cast_to_fp8 for a and per_block_cast_to_fp8 for b (a rough sketch of the idea follows these notes).
- The function internally uses the DeepGEMM backend for optimized FP8 computation.
- All input tensors must be on the same CUDA device.
- The block size used for scaling is determined by the scale_granularity_mnk parameter.
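For reference, the sketch below shows one way groupwise FP8 quantization could produce scale tensors with the documented shapes, assuming a block size of 128. It is not the implementation of flashinfer.utils.per_token_cast_to_fp8 / per_block_cast_to_fp8; the FP8 range constant and the clamp epsilon are simplifying assumptions.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def per_token_quant_sketch(x: torch.Tensor, block: int = 128):
    # x: (m, k) float32 -> FP8 values plus (m, k // block) float32 scales (one scale per 1x128 tile)
    m, k = x.shape
    xb = x.view(m, k // block, block)
    scale = xb.abs().amax(dim=-1).clamp(min=1e-4) / FP8_MAX      # (m, k // block)
    q = (xb / scale.unsqueeze(-1)).to(torch.float8_e4m3fn)
    return q.view(m, k), scale

def per_block_quant_sketch(x: torch.Tensor, block: int = 128):
    # x: (n, k) float32 -> FP8 values plus (n // block, k // block) float32 scales (one scale per 128x128 tile)
    n, k = x.shape
    xb = x.view(n // block, block, k // block, block)
    scale = xb.abs().amax(dim=(1, 3)).clamp(min=1e-4) / FP8_MAX  # (n // block, k // block)
    q = (xb / scale[:, None, :, None]).to(torch.float8_e4m3fn)
    return q.view(n, k), scale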