flashinfer.gemm
This module provides a set of GEMM operations.
FP4 GEMM
| | MM FP4 |
FP8 GEMM
| | BMM FP8 |
| | Perform matrix multiplication with FP8 data types using groupwise scaling. |
| | Perform group GEMM with FP8 data types using groupwise scaling. |
| | Perform grouped matrix multiplication with FP8 data types using DeepGEMM backend. |
| | Perform batch matrix multiplication with FP8 data types using DeepGEMM backend. |
Mixed Precision GEMM (fp8 x fp4)
| | Perform group GEMM with MXFP4 data types using groupwise scaling. |
Grouped GEMM (Ampere/Hopper)
- class flashinfer.gemm.SegmentGEMMWrapper(float_workspace_buffer: Tensor, backend: str = 'auto')
- Wrapper for segment GEMM kernels.
- Example

>>> import torch
>>> from flashinfer import SegmentGEMMWrapper
>>> # create a 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
>>> segment_gemm = SegmentGEMMWrapper(workspace_buffer)
>>> seq_lens = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda")
>>> # create packed input tensor (10 = 1 + 2 + 3 + 4)
>>> x = torch.randn(10, 128, device="cuda", dtype=torch.float16)
>>> # create weight tensor with 4 weights, each with 128 input and 256 output channels, column major
>>> weights = torch.randn(4, 256, 128, device="cuda", dtype=torch.float16)
>>> # compute the segment GEMM
>>> y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens)
>>> y.shape
torch.Size([10, 256])
>>> y_ref_0 = torch.matmul(x[:1], weights[0].t())
>>> torch.allclose(y[:1], y_ref_0)
True
>>> y_ref_1 = torch.matmul(x[1:3], weights[1].t())
>>> torch.allclose(y[1:3], y_ref_1)
True
>>> y_ref_2 = torch.matmul(x[3:6], weights[2].t())
>>> torch.allclose(y[3:6], y_ref_2)
True
>>> y_ref_3 = torch.matmul(x[6:], weights[3].t())
>>> torch.allclose(y[6:], y_ref_3)
True
>>>
>>> # another example with weight indices
>>> weight_indices = torch.tensor([0, 1, 0, 1], dtype=torch.int64, device="cuda")
>>> y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens, weight_indices=weight_indices)
>>> y.shape
torch.Size([10, 256])
>>> y_ref_0 = torch.matmul(x[:1], weights[0].t())
>>> torch.allclose(y[:1], y_ref_0)
True
>>> y_ref_1 = torch.matmul(x[1:3], weights[1].t())
>>> torch.allclose(y[1:3], y_ref_1)
True
>>> y_ref_2 = torch.matmul(x[3:6], weights[0].t())
>>> torch.allclose(y[3:6], y_ref_2)
True
>>> y_ref_3 = torch.matmul(x[6:], weights[1].t())
>>> torch.allclose(y[6:], y_ref_3)
True

- __init__(float_workspace_buffer: Tensor, backend: str = 'auto') -> None
- Initialize the wrapper.
- Parameters:
- float_workspace_buffer (torch.Tensor) – The workspace buffer for the kernels. It stores intermediate results in the CUTLASS segment GEMM kernels. The recommended size is 128MB.
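The recommended size is given in bytes, so a one-byte dtype makes the element count equal to the byte count. The sketch below mirrors the `torch.int8` allocation used in the class example. The helper name `make_workspace_buffer` and its `device` default are hypothetical and not part of flashinfer; the default is `"cpu"` only so the snippet runs without a GPU, whereas the class example allocates the buffer on `"cuda"`.

```python
import torch


def make_workspace_buffer(size_mb: int = 128, device: str = "cpu") -> torch.Tensor:
    """Allocate an uninitialized byte buffer of size_mb mebibytes (hypothetical helper)."""
    num_bytes = size_mb * 1024 * 1024
    # int8 has a 1-byte element size, so numel() equals the size in bytes.
    return torch.empty(num_bytes, dtype=torch.int8, device=device)


workspace_buffer = make_workspace_buffer()
size_in_mb = workspace_buffer.numel() * workspace_buffer.element_size() // (1024 * 1024)
print(size_in_mb)  # 128
```

When passing the buffer to `SegmentGEMMWrapper`, call the helper with `device="cuda"` as in the class example.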
 
- reset_workspace_buffer(float_workspace_buffer: Tensor, int_workspace_buffer: Tensor) -> None
- Reset the workspace buffer.
- Parameters:
- float_workspace_buffer (torch.Tensor) – The new float workspace buffer for the kernels.
- int_workspace_buffer (torch.Tensor) – The new int workspace buffer for the kernels.
 
 
- run(x: Tensor, weights: Tensor, batch_size: int, weight_column_major: bool, out: Tensor | None = None, seg_lens: Tensor | None = None, seg_indptr: Tensor | None = None, weight_indices: Tensor | None = None) -> Tensor
- Run the segment GEMM kernel.

Compute the matrix multiplication between a batch of input tensors (with a variable number of rows but a fixed number of columns) and a batch of weight tensors with a fixed number of rows and columns:

\[y[i] = x[i] \times W[i]\]

If weight_indices is provided, the weight tensor for each segment is selected by the indices in the weight_indices tensor:

\[y[i] = x[i] \times W[\text{weight_indices}[i]]\]

We use a Ragged Tensor to represent the input tensor x and the output tensor y, and each x[i] is a segment of the concatenated tensor. Please see the Ragged Tensor tutorial for more details. We use a seg_lens or seg_indptr tensor (either works) to indicate the start and end of each segment, where seg_indptr is the cumulative sum of the seg_lens tensor (with an additional 0 at the beginning):

\[\text{seg_indptr}[i] = \sum_{j=0}^{i-1} \text{seg_lens}[j], \quad \text{seg_indptr}[0] = 0\]

- If seg_lens is provided, then x has shape (sum(seg_lens), d_in) and y has shape (sum(seg_lens), d_out), where d_in is the number of columns of the input tensor and d_out is the number of columns of the output tensor.
- If seg_indptr is provided, then x has shape (seg_indptr[-1], d_in) and y has shape (seg_indptr[-1], d_out).
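The relation between seg_lens and seg_indptr can be sketched in plain PyTorch as follows. The helper name `seg_lens_to_indptr` is hypothetical and only illustrates the formula; it is not claimed to be how flashinfer computes the indptr internally.

```python
import torch


def seg_lens_to_indptr(seg_lens: torch.Tensor) -> torch.Tensor:
    """Return the cumulative sum of seg_lens with a leading 0 (hypothetical helper)."""
    indptr = torch.zeros(seg_lens.numel() + 1, dtype=torch.int64, device=seg_lens.device)
    # indptr[i] = seg_lens[0] + ... + seg_lens[i - 1], and indptr[0] stays 0.
    indptr[1:] = torch.cumsum(seg_lens, dim=0)
    return indptr


seg_lens = torch.tensor([1, 2, 3, 4], dtype=torch.int64)
seg_indptr = seg_lens_to_indptr(seg_lens)
print(seg_indptr.tolist())  # [0, 1, 3, 6, 10]
```

Segment i then occupies rows `seg_indptr[i]:seg_indptr[i + 1]` of the packed tensor, and `seg_indptr[-1]` equals `sum(seg_lens)`.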
 - Parameters:
- x (torch.Tensor) – The input tensor with shape (sum(seg_lens), d_in).
- weights (torch.Tensor) – The 3D weight tensor with shape (num_weights, d_in, d_out) if weight_column_major is False, or (num_weights, d_out, d_in) if weight_column_major is True.
- batch_size (int) – The number of segments.
- weight_column_major (bool) – Whether the weight tensor is column major.
- out (Optional[torch.Tensor]) – The output tensor, with shape (sum(seg_lens), d_out). If not provided, a new tensor will be created internally.
- seg_lens (Optional[torch.Tensor]) – The length of each segment, with shape (batch_size,). Expects a 1D tensor of dtype torch.int64.
- seg_indptr (Optional[torch.Tensor]) – The indptr of the segments, with shape (batch_size + 1,). Expects a 1D tensor of dtype torch.int64. If this is provided, then seg_lens will be ignored; otherwise seg_indptr will be computed internally from seg_lens.
- weight_indices (Optional[torch.Tensor]) – The indices of the weight tensor to be selected for each segment, with shape (batch_size,). Expects a 1D tensor of dtype torch.int64. If this is provided, then the weight tensor will be selected based on the indices in this tensor.
 
- Returns:
- The output tensor with shape (sum(seg_lens), d_out).
- Return type:
- torch.Tensor
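To make the semantics of run concrete, here is a minimal pure-PyTorch reference loop that follows the formulas and parameter descriptions above. It is a sketch for checking results on small inputs, not the library's kernel: the function name `segment_gemm_reference` is hypothetical, it takes seg_indptr directly, and it runs on CPU in float32 so it needs no GPU.

```python
from typing import Optional

import torch


def segment_gemm_reference(
    x: torch.Tensor,
    weights: torch.Tensor,
    seg_indptr: torch.Tensor,
    weight_column_major: bool,
    weight_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Slow reference for y[i] = x[i] @ W[i] over ragged segments (hypothetical helper)."""
    # Column-major weights have shape (num_weights, d_out, d_in).
    d_out = weights.shape[1] if weight_column_major else weights.shape[2]
    y = torch.zeros(x.shape[0], d_out, dtype=x.dtype, device=x.device)
    for i in range(seg_indptr.numel() - 1):
        start, end = int(seg_indptr[i]), int(seg_indptr[i + 1])
        # Without weight_indices, segment i uses weight i.
        w = weights[i if weight_indices is None else int(weight_indices[i])]
        if weight_column_major:
            w = w.t()  # (d_out, d_in) -> (d_in, d_out)
        y[start:end] = x[start:end] @ w
    return y


x = torch.randn(10, 8)
weights = torch.randn(2, 16, 8)  # 2 weights, column major: (d_out=16, d_in=8)
seg_indptr = torch.tensor([0, 1, 3, 6, 10], dtype=torch.int64)
weight_indices = torch.tensor([0, 1, 0, 1], dtype=torch.int64)
y = segment_gemm_reference(x, weights, seg_indptr, True, weight_indices)
print(tuple(y.shape))  # (10, 16)
```

A loop like this is far slower than a fused kernel, but it gives a baseline to compare kernel output against when debugging segment boundaries or weight layout.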