flashinfer.gemm#

This module provides a set of GEMM operations.

FP8 Batch GEMM#

bmm_fp8(A, B, A_scale, B_scale, dtype[, out])

Batched matrix multiplication (BMM) with FP8 inputs, scaled by A_scale and B_scale.
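
A minimal usage sketch, assuming the signature above. The to_float8 quantization helper is hand-rolled here and not part of flashinfer; the second operand is passed as a transposed view, assuming a column-major layout requirement for the FP8 GEMM.

>>> import torch
>>> import flashinfer
>>> def to_float8(x, dtype=torch.float8_e4m3fn):
...     # Per-tensor quantization: scale so the max magnitude maps to the FP8 max.
...     finfo = torch.finfo(dtype)
...     amax = x.abs().max().clamp(min=1e-12)
...     scale = finfo.max / amax
...     x_fp8 = (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
...     return x_fp8, scale.float().reciprocal()  # inverse scale for dequantization
...
>>> A = torch.randn(16, 48, 64, device="cuda", dtype=torch.bfloat16)
>>> B = torch.randn(16, 80, 64, device="cuda", dtype=torch.bfloat16)
>>> A_fp8, A_inv_scale = to_float8(A)
>>> B_fp8, B_inv_scale = to_float8(B)
>>> out = flashinfer.bmm_fp8(A_fp8, B_fp8.transpose(-2, -1), A_inv_scale, B_inv_scale, torch.bfloat16)
>>> out.shape
torch.Size([16, 48, 80])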

Grouped GEMM#

class flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer: torch.Tensor)#

Wrapper for segment GEMM kernels.

Example

>>> import torch
>>> from flashinfer import SegmentGEMMWrapper
>>> # create a 1MB workspace buffer
>>> workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
>>> segment_gemm = SegmentGEMMWrapper(workspace_buffer)
>>> seq_lens = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda")
>>> # create packed input tensor (10 = 1 + 2 + 3 + 4)
>>> x = torch.randn(10, 128, device="cuda", dtype=torch.float16)
>>> # create weight tensor with 4 weights, each with 128 input and 256 output channels, column major
>>> weights = torch.randn(4, 256, 128, device="cuda", dtype=torch.float16)
>>> # compute the segment GEMM
>>> y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens)
>>> y.shape
torch.Size([10, 256])
>>> y_ref_0 = torch.matmul(x[:1], weights[0].t())
>>> torch.allclose(y[:1], y_ref_0)
True
>>> y_ref_1 = torch.matmul(x[1:3], weights[1].t())
>>> torch.allclose(y[1:3], y_ref_1)
True
>>> y_ref_2 = torch.matmul(x[3:6], weights[2].t())
>>> torch.allclose(y[3:6], y_ref_2)
True
>>> y_ref_3 = torch.matmul(x[6:], weights[3].t())
>>> torch.allclose(y[6:], y_ref_3)
True
>>>
>>> # another example with weight indices
>>> weight_indices = torch.tensor([0, 1, 0, 1], dtype=torch.int64, device="cuda")
>>> y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens, weight_indices=weight_indices)
>>> y.shape
torch.Size([10, 256])
>>> y_ref_0 = torch.matmul(x[:1], weights[0].t())
>>> torch.allclose(y[:1], y_ref_0)
True
>>> y_ref_1 = torch.matmul(x[1:3], weights[1].t())
>>> torch.allclose(y[1:3], y_ref_1)
True
>>> y_ref_2 = torch.matmul(x[3:6], weights[0].t())
>>> torch.allclose(y[3:6], y_ref_2)
True
>>> y_ref_3 = torch.matmul(x[6:], weights[1].t())
>>> torch.allclose(y[6:], y_ref_3)
True
__init__(workspace_buffer: torch.Tensor) None#

Initialize the wrapper.

Parameters:

workspace_buffer (torch.Tensor) – The workspace buffer for the kernels; it is used to store metadata for the segment GEMM, whose size is proportional to the number of segments (batch size). A 1MB workspace is enough for most cases.

reset_workspace_buffer(new_workspace_buffer: torch.Tensor) None#

Reset the workspace buffer.

Parameters:

new_workspace_buffer (torch.Tensor) – The new workspace buffer; it should reside on the same device as the input tensors.
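
A quick sketch of swapping in a larger buffer when a big batch outgrows the default 1MB allocation (the 32MB size here is an arbitrary choice for illustration):

>>> larger_workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
>>> segment_gemm.reset_workspace_buffer(larger_workspace_buffer)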

run(x: torch.Tensor, weights: torch.Tensor, batch_size: int, weight_column_major: bool, seg_lens: torch.Tensor | None = None, seg_indptr: torch.Tensor | None = None, weight_indices: torch.Tensor | None = None) torch.Tensor#

Run the segment GEMM kernel.

Compute the matrix multiplication between a batch of input tensors (each with a variable number of rows but a fixed number of columns) and a batch of weight tensors with fixed numbers of rows and columns:

\[y[i] = x[i] \times W[i]\]

If weight_indices is provided, the weight tensor for each segment is selected according to the indices in the weight_indices tensor:

\[y[i] = x[i] \times W[\text{weight_indices}[i]]\]

We use a Ragged Tensor to represent the input tensor x and the output tensor y; each x[i] is a segment of the concatenated tensor. Please see the Ragged Tensor tutorial for more details. Either a seg_lens or a seg_indptr tensor can be used to indicate the start and end of each segment, where seg_indptr is the cumulative sum of seg_lens with an additional 0 at the beginning (see the sketch after the list below):

\[\text{seg_indptr}[i] = \sum_{j=0}^{i-1} \text{seg_lens}[j], \quad \text{seg_indptr}[0] = 0\]
  • If seg_lens is provided, then x has shape (sum(seg_lens), d_in) and y has shape (sum(seg_lens), d_out), where d_in is the number of columns of the input tensor and d_out is the number of columns of the output tensor.

  • If seg_indptr is provided, then x has shape (seg_indptr[-1], d_in) and y has shape (seg_indptr[-1], d_out).
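
For illustration, a small sketch of how seg_indptr is derived from seg_lens (using the same segment lengths as the example above):

>>> seg_lens = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda")
>>> # prefix sum with a leading zero; segment i spans rows seg_indptr[i]:seg_indptr[i + 1] of x
>>> seg_indptr = torch.cat([seg_lens.new_zeros(1), torch.cumsum(seg_lens, dim=0)])
>>> seg_indptr
tensor([ 0,  1,  3,  6, 10], device='cuda:0')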

Parameters:
  • x (torch.Tensor) – The input tensor with shape (sum(seg_lens), d_in).

  • weights (torch.Tensor) – The 3D weight tensor with shape (num_weights, d_in, d_out) if weight_column_major is False, or (num_weights, d_out, d_in) if weight_column_major is True.

  • batch_size (int) – The number of segments.

  • weight_column_major (bool) – Whether the weight tensor is column major.

  • seg_lens (Optional[torch.Tensor]) – The length of each segment; a 1D tensor of dtype torch.int64 with shape (batch_size,).

  • seg_indptr (Optional[torch.Tensor]) – The indptr of the segments; a 1D tensor of dtype torch.int64 with shape (batch_size + 1,). If this is provided, seg_lens will be ignored; otherwise seg_indptr will be computed internally from seg_lens.

  • weight_indices (Optional[torch.Tensor]) – The indices of the weight tensor to select for each segment; a 1D tensor of dtype torch.int64 with shape (batch_size,). If this is provided, the weight tensor will be selected based on the indices in this tensor.

Returns:

The output tensor with shape (sum(seg_lens), d_out).

Return type:

torch.Tensor
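
For reference, the computation performed by run can be sketched in pure PyTorch as a per-segment loop (segment_gemm_ref is a hypothetical helper, not part of flashinfer; the actual kernel fuses this loop into a single grouped GEMM):

>>> import torch
>>> def segment_gemm_ref(x, weights, seg_indptr, weight_column_major=True, weight_indices=None):
...     outs = []
...     for i in range(seg_indptr.numel() - 1):
...         start, end = seg_indptr[i], seg_indptr[i + 1]
...         w = weights[weight_indices[i]] if weight_indices is not None else weights[i]
...         # column-major weights are stored as (d_out, d_in), so transpose before the matmul
...         outs.append(x[start:end] @ (w.t() if weight_column_major else w))
...     return torch.cat(outs, dim=0)
...
>>> x = torch.randn(10, 128, device="cuda", dtype=torch.float16)
>>> weights = torch.randn(4, 256, 128, device="cuda", dtype=torch.float16)
>>> seg_indptr = torch.tensor([0, 1, 3, 6, 10], device="cuda")
>>> segment_gemm_ref(x, weights, seg_indptr).shape
torch.Size([10, 256])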