flashinfer.gemm#
This module provides a set of GEMM operations.
FP8 Batch GEMM#
- BMM FP8
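The FP8 batch GEMM operation is exposed as bmm_fp8. The snippet below is a minimal sketch of how it might be called; the to_float8 helper, the assumed signature bmm_fp8(A, B, A_scale, B_scale, dtype), and the column-major requirement on the second operand are assumptions, so consult the BMM FP8 reference page for the authoritative interface.

>>> import torch
>>> import flashinfer
>>>
>>> def to_float8(x, dtype=torch.float8_e4m3fn):
...     # Hypothetical helper: per-tensor FP8 quantization returning the FP8
...     # tensor and the inverse scale expected by the FP8 GEMM.
...     finfo = torch.finfo(dtype)
...     amax = x.abs().amax().clamp(min=1e-12)
...     scale = finfo.max / amax
...     x_fp8 = (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
...     return x_fp8, scale.float().reciprocal()
...
>>> # batch of 16 GEMMs: (48 x 64) @ (64 x 80) -> (48 x 80)
>>> A = torch.randn(16, 48, 64, device="cuda", dtype=torch.bfloat16)
>>> B = torch.randn(16, 80, 64, device="cuda", dtype=torch.bfloat16)
>>> A_fp8, A_inv_scale = to_float8(A)
>>> B_fp8, B_inv_scale = to_float8(B)
>>> # assumed signature: bmm_fp8(A, B, A_scale, B_scale, dtype); the second
>>> # operand is assumed to be column-major, hence the transposed view
>>> out = flashinfer.bmm_fp8(A_fp8, B_fp8.transpose(-2, -1), A_inv_scale, B_inv_scale, torch.bfloat16)
>>> out.shape
torch.Size([16, 48, 80])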
Grouped GEMM#
- class flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer: torch.Tensor)#
Wrapper for segment GEMM kernels.
Example
>>> import torch
>>> from flashinfer import SegmentGEMMWrapper
>>> # create a 1MB workspace buffer
>>> workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
>>> segment_gemm = SegmentGEMMWrapper(workspace_buffer)
>>> seq_lens = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda")
>>> # create packed input tensor (10 = 1 + 2 + 3 + 4)
>>> x = torch.randn(10, 128, device="cuda", dtype=torch.float16)
>>> # create weight tensor with 4 weights, each with 128 input and 256 output channels, column major
>>> weights = torch.randn(4, 256, 128, device="cuda", dtype=torch.float16)
>>> # compute the segment GEMM
>>> y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens)
>>> y.shape
torch.Size([10, 256])
>>> y_ref_0 = torch.matmul(x[:1], weights[0].t())
>>> torch.allclose(y[:1], y_ref_0)
True
>>> y_ref_1 = torch.matmul(x[1:3], weights[1].t())
>>> torch.allclose(y[1:3], y_ref_1)
True
>>> y_ref_2 = torch.matmul(x[3:6], weights[2].t())
>>> torch.allclose(y[3:6], y_ref_2)
True
>>> y_ref_3 = torch.matmul(x[6:], weights[3].t())
>>> torch.allclose(y[6:], y_ref_3)
True
>>>
>>> # another example with weight indices
>>> weight_indices = torch.tensor([0, 1, 0, 1], dtype=torch.int64, device="cuda")
>>> y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens, weight_indices=weight_indices)
>>> y.shape
torch.Size([10, 256])
>>> y_ref_0 = torch.matmul(x[:1], weights[0].t())
>>> torch.allclose(y[:1], y_ref_0)
True
>>> y_ref_1 = torch.matmul(x[1:3], weights[1].t())
>>> torch.allclose(y[1:3], y_ref_1)
True
>>> y_ref_2 = torch.matmul(x[3:6], weights[0].t())
>>> torch.allclose(y[3:6], y_ref_2)
True
>>> y_ref_3 = torch.matmul(x[6:], weights[1].t())
>>> torch.allclose(y[6:], y_ref_3)
True
- __init__(workspace_buffer: torch.Tensor) → None#
Initialize the wrapper.
- Parameters:
workspace_buffer (torch.Tensor) – The workspace buffer for the kernels. It is used to store metadata for the segment GEMM, whose size is proportional to the number of segments (batch size); a 1MB workspace is enough for most cases.
- reset_workspace_buffer(new_workspace_buffer: torch.Tensor) → None#
Reset the workspace buffer.
- Parameters:
new_workspace_buffer (torch.Tensor) – The new workspace buffer. Its device should be the same as the device of the input tensors.
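For example, if the batch size grows beyond what the current buffer can hold, a larger buffer can be swapped in without recreating the wrapper (the 8MB size below is only an illustrative choice):

>>> import torch
>>> from flashinfer import SegmentGEMMWrapper
>>> segment_gemm = SegmentGEMMWrapper(
...     torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
... )
>>> # later, swap in a larger buffer on the same device as the input tensors
>>> larger_buffer = torch.empty(8 * 1024 * 1024, dtype=torch.int8, device="cuda")
>>> segment_gemm.reset_workspace_buffer(larger_buffer)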
- run(x: torch.Tensor, weights: torch.Tensor, batch_size: int, weight_column_major: bool, seg_lens: torch.Tensor | None = None, seg_indptr: torch.Tensor | None = None, weight_indices: torch.Tensor | None = None) → torch.Tensor#
Run the segment GEMM kernel.
Compute the matrix multiplication between a batch of input tensors (with a variable number of rows but a fixed number of columns) and a batch of weight tensors with a fixed number of rows and columns:
\[y[i] = x[i] \times W[i]\]
If weight_indices is provided, we will select the weight tensor based on the indices in the weight_indices tensor:
\[y[i] = x[i] \times W[\text{weight_indices}[i]]\]
We use a Ragged Tensor to represent the input tensor x and the output tensor y: each x[i] is a segment of the concatenated tensor. Please see the Ragged Tensor tutorial for more details. We use a seg_lens or seg_indptr tensor (either one works) to indicate the start and end of each segment, where seg_indptr is the cumulative sum of the seg_lens tensor with an additional 0 at the beginning (see the sketch after the list below):
\[\text{seg_indptr}[i] = \sum_{j=0}^{i-1} \text{seg_lens}[j], \quad \text{seg_indptr}[0] = 0\]
- If seg_lens is provided, then x has shape (sum(seg_lens), d_in) and y has shape (sum(seg_lens), d_out), where d_in is the number of columns of the input tensor and d_out is the number of columns of the output tensor.
- If seg_indptr is provided, then x has shape (seg_indptr[-1], d_in) and y has shape (seg_indptr[-1], d_out).
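The following sketch only illustrates the relationship between seg_lens and seg_indptr described above, using the segment lengths from the example at the top of this page:

>>> import torch
>>> seg_lens = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda")
>>> # seg_indptr is the cumulative sum of seg_lens with a leading 0;
>>> # segment i spans rows seg_indptr[i]:seg_indptr[i + 1] of the packed x / y
>>> seg_indptr = torch.cat([
...     torch.zeros(1, dtype=torch.int64, device="cuda"),
...     torch.cumsum(seg_lens, dim=0),
... ])
>>> seg_indptr
tensor([ 0,  1,  3,  6, 10], device='cuda:0')

Passing seg_indptr=seg_indptr to run() (and omitting seg_lens) should therefore describe the same segmentation as seg_lens=seg_lens in the example above.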
- Parameters:
x (torch.Tensor) – The input tensor with shape (sum(seg_lens), d_in).
weights (torch.Tensor) – The 3D weight tensor with shape (num_weights, d_in, d_out) if weight_column_major is False, or (num_weights, d_out, d_in) if weight_column_major is True.
batch_size (int) – The number of segments.
weight_column_major (bool) – Whether the weight tensor is column major.
seg_lens (Optional[torch.Tensor]) – The length of each segment, with shape (batch_size,); expects a 1D tensor of dtype torch.int64.
seg_indptr (Optional[torch.Tensor]) – The indptr of the segments, with shape (batch_size + 1,); expects a 1D tensor of dtype torch.int64. If this is provided, seg_lens will be ignored; otherwise seg_indptr will be computed internally from seg_lens.
weight_indices (Optional[torch.Tensor]) – The indices of the weight tensor to select for each segment, with shape (batch_size,); expects a 1D tensor of dtype torch.int64. If this is provided, the weight tensor will be selected based on the indices in this tensor.
- Returns:
The output tensor with shape (sum(seg_lens), d_out).
- Return type:
torch.Tensor
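As a sanity check, the semantics described above can be written as a plain PyTorch loop. This is only a sketch of what run() computes, not how the fused kernel is implemented, and the helper name segment_gemm_reference is hypothetical:

>>> import torch
>>>
>>> def segment_gemm_reference(x, weights, seg_lens, weight_column_major, weight_indices=None):
...     # pure-PyTorch reference for y[i] = x[i] @ W[weight_indices[i]] (or W[i])
...     indptr = [0] + torch.cumsum(seg_lens, dim=0).tolist()
...     outputs = []
...     for i in range(len(seg_lens)):
...         idx = int(weight_indices[i]) if weight_indices is not None else i
...         w = weights[idx]
...         if weight_column_major:   # weight stored as (d_out, d_in)
...             w = w.t()             # use it as (d_in, d_out)
...         outputs.append(x[indptr[i]:indptr[i + 1]] @ w)
...     return torch.cat(outputs, dim=0)  # packed output of shape (sum(seg_lens), d_out)

With the tensors from the example above, segment_gemm_reference(x, weights, seq_lens, True) should match segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens) up to numerical tolerance.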