flashinfer.gemm#
This module provides a set of GEMM operations.
FP8 Batch GEMM#
- BMM FP8
Grouped GEMM#
- class flashinfer.gemm.SegmentGEMMWrapper(float_workspace_buffer: torch.Tensor)#
Wrapper for segment GEMM kernels.
Example
>>> import torch
>>> from flashinfer import SegmentGEMMWrapper
>>> # create a 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
>>> segment_gemm = SegmentGEMMWrapper(workspace_buffer)
>>> seq_lens = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda")
>>> # create packed input tensor (10 = 1 + 2 + 3 + 4)
>>> x = torch.randn(10, 128, device="cuda", dtype=torch.float16)
>>> # create weight tensor with 4 weights, each with 128 input and 256 output channels, column major
>>> weights = torch.randn(4, 256, 128, device="cuda", dtype=torch.float16)
>>> # compute the segment GEMM
>>> y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens)
>>> y.shape
torch.Size([10, 256])
>>> y_ref_0 = torch.matmul(x[:1], weights[0].t())
>>> torch.allclose(y[:1], y_ref_0)
True
>>> y_ref_1 = torch.matmul(x[1:3], weights[1].t())
>>> torch.allclose(y[1:3], y_ref_1)
True
>>> y_ref_2 = torch.matmul(x[3:6], weights[2].t())
>>> torch.allclose(y[3:6], y_ref_2)
True
>>> y_ref_3 = torch.matmul(x[6:], weights[3].t())
>>> torch.allclose(y[6:], y_ref_3)
True
>>>
>>> # another example with weight indices
>>> weight_indices = torch.tensor([0, 1, 0, 1], dtype=torch.int64, device="cuda")
>>> y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens, weight_indices=weight_indices)
>>> y.shape
torch.Size([10, 256])
>>> y_ref_0 = torch.matmul(x[:1], weights[0].t())
>>> torch.allclose(y[:1], y_ref_0)
True
>>> y_ref_1 = torch.matmul(x[1:3], weights[1].t())
>>> torch.allclose(y[1:3], y_ref_1)
True
>>> y_ref_2 = torch.matmul(x[3:6], weights[0].t())
>>> torch.allclose(y[3:6], y_ref_2)
True
>>> y_ref_3 = torch.matmul(x[6:], weights[1].t())
>>> torch.allclose(y[6:], y_ref_3)
True
- __init__(float_workspace_buffer: torch.Tensor) None #
Initialize the wrapper.
- Parameters:
float_workspace_buffer (torch.Tensor) – The workspace buffer for the kernels; it is used to store intermediate results in the cutlass segment GEMM kernels. The recommended size is 128MB.
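For reference, the constructor is used exactly as in the class-level example above; allocating a buffer of the recommended size looks like this:
>>> import torch
>>> from flashinfer import SegmentGEMMWrapper
>>> # allocate the recommended 128MB float workspace buffer on the GPU
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
>>> segment_gemm = SegmentGEMMWrapper(workspace_buffer)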
- reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor) None #
Reset the workspace buffer.
- Parameters:
float_workspace_buffer (torch.Tensor) – The new float workspace buffer for the kernels.
int_workspace_buffer (torch.Tensor) – The new int workspace buffer for the kernels.
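A minimal sketch of swapping in fresh buffers on an existing wrapper; the buffer sizes and dtypes below are illustrative assumptions, not values specified by this documentation:
>>> import torch
>>> # sizes/dtypes here are illustrative assumptions, not documented requirements
>>> new_float_workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
>>> new_int_workspace = torch.empty(8 * 1024 * 1024, dtype=torch.int8, device="cuda")
>>> segment_gemm.reset_workspace_buffer(new_float_workspace, new_int_workspace)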
- run(x: torch.Tensor, weights: torch.Tensor, batch_size: int, weight_column_major: bool, seg_lens: torch.Tensor | None = None, seg_indptr: torch.Tensor | None = None, weight_indices: torch.Tensor | None = None) torch.Tensor #
Run the segment GEMM kernel.
Compute the matrix multiplication between a batch of input tensors (each with a variable number of rows but a fixed number of columns) and a batch of weight tensors with fixed numbers of rows and columns:
\[y[i] = x[i] \times W[i]\]
If weight_indices is provided, we will select the weight tensor based on the indices in the weight_indices tensor:
\[y[i] = x[i] \times W[\text{weight_indices}[i]]\]
We use a Ragged Tensor to represent the input tensor x and the output tensor y, where each x[i] is a segment of the concatenated tensor. Please see the Ragged Tensor tutorial for more details. We use a seg_lens or seg_indptr tensor (either works) to indicate the start and end of each segment, where seg_indptr is the cumulative sum of the seg_lens tensor (with an additional 0 at the beginning):
\[\text{seg_indptr}[i] = \sum_{j=0}^{i-1} \text{seg_lens}[j], \quad \text{seg_indptr}[0] = 0\]
- If seg_lens is provided, then x has shape (sum(seg_lens), d_in) and y has shape (sum(seg_lens), d_out), where d_in is the number of columns of the input tensor and d_out is the number of columns of the output tensor.
- If seg_indptr is provided, then x has shape (seg_indptr[-1], d_in) and y has shape (seg_indptr[-1], d_out).
The relationship between seg_lens and seg_indptr is illustrated in the sketch below.
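A minimal sketch of this relationship, reusing segment_gemm, x, and weights from the class-level example above (the explicit cumulative-sum construction here is illustrative, not part of the API):
>>> import torch
>>> seg_lens = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda")
>>> # seg_indptr[i] = sum(seg_lens[:i]), with a leading 0
>>> seg_indptr = torch.cat([torch.zeros(1, dtype=torch.int64, device="cuda"), torch.cumsum(seg_lens, dim=0)])
>>> seg_indptr.tolist()
[0, 1, 3, 6, 10]
>>> # passing seg_indptr instead of seg_lens selects the same segments
>>> y = segment_gemm.run(x, weights, 4, True, seg_indptr=seg_indptr)
>>> y.shape
torch.Size([10, 256])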
- Parameters:
x (torch.Tensor) – The input tensor with shape (sum(seg_lens), d_in).
weights (torch.Tensor) – The 3D weight tensor with shape (num_weights, d_in, d_out) if weight_column_major is False, or (num_weights, d_out, d_in) if weight_column_major is True.
batch_size (int) – The number of segments.
weight_column_major (bool) – Whether the weight tensor is column major (see the row-major sketch at the end of this section).
seg_lens (Optional[torch.Tensor]) – The length of each segment, with shape (batch_size,); expects a 1D tensor of dtype torch.int64.
seg_indptr (Optional[torch.Tensor]) – The indptr of the segments, with shape (batch_size + 1,); expects a 1D tensor of dtype torch.int64. If this is provided, seg_lens will be ignored; otherwise seg_indptr will be computed internally from seg_lens.
weight_indices (Optional[torch.Tensor]) – The indices of the weight tensor to be selected for each segment, with shape (batch_size,); expects a 1D tensor of dtype torch.int64. If this is provided, the weight tensor will be selected based on the indices in this tensor.
- Returns:
The output tensor with shape (sum(seg_lens), d_out).
- Return type:
torch.Tensor
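Following the shape convention above, a row-major weight layout can be used by setting weight_column_major to False. This sketch reuses x, segment_gemm, and seq_lens from the class-level example and is illustrative rather than taken from the original docs:
>>> # row-major weights: shape (num_weights, d_in, d_out), so weight_column_major=False
>>> weights_rm = torch.randn(4, 128, 256, device="cuda", dtype=torch.float16)
>>> y = segment_gemm.run(x, weights_rm, 4, False, seg_lens=seq_lens)
>>> y.shape
torch.Size([10, 256])
>>> # each segment is multiplied by its weight directly, with no transpose
>>> torch.allclose(y[:1], torch.matmul(x[:1], weights_rm[0]))
True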