flashinfer.gemm

This module provides a set of GEMM operations.

FP8 Batch GEMM

bmm_fp8(A, B, A_scale, B_scale, dtype[, out])

Batched matrix multiplication (BMM) with FP8 inputs.

Grouped GEMM

class flashinfer.gemm.SegmentGEMMWrapper(float_workspace_buffer: torch.Tensor)

Wrapper for segment GEMM kernels.

Example

>>> import torch
>>> from flashinfer import SegmentGEMMWrapper
>>> # create a 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
>>> segment_gemm = SegmentGEMMWrapper(workspace_buffer)
>>> seq_lens = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda")
>>> # create packed input tensor (10 = 1 + 2 + 3 + 4)
>>> x = torch.randn(10, 128, device="cuda", dtype=torch.float16)
>>> # create weight tensor with 4 weights, each with 128 input and 256 output channels, column major
>>> weights = torch.randn(4, 256, 128, device="cuda", dtype=torch.float16)
>>> # compute the segment GEMM
>>> y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens)
>>> y.shape
torch.Size([10, 256])
>>> y_ref_0 = torch.matmul(x[:1], weights[0].t())
>>> torch.allclose(y[:1], y_ref_0)
True
>>> y_ref_1 = torch.matmul(x[1:3], weights[1].t())
>>> torch.allclose(y[1:3], y_ref_1)
True
>>> y_ref_2 = torch.matmul(x[3:6], weights[2].t())
>>> torch.allclose(y[3:6], y_ref_2)
True
>>> y_ref_3 = torch.matmul(x[6:], weights[3].t())
>>> torch.allclose(y[6:], y_ref_3)
True
>>>
>>> # another example with weight indices
>>> weight_indices = torch.tensor([0, 1, 0, 1], dtype=torch.int64, device="cuda")
>>> y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens, weight_indices=weight_indices)
>>> y.shape
torch.Size([10, 256])
>>> y_ref_0 = torch.matmul(x[:1], weights[0].t())
>>> torch.allclose(y[:1], y_ref_0)
True
>>> y_ref_1 = torch.matmul(x[1:3], weights[1].t())
>>> torch.allclose(y[1:3], y_ref_1)
True
>>> y_ref_2 = torch.matmul(x[3:6], weights[0].t())
>>> torch.allclose(y[3:6], y_ref_2)
True
>>> y_ref_3 = torch.matmul(x[6:], weights[1].t())
>>> torch.allclose(y[6:], y_ref_3)
True

__init__(float_workspace_buffer: torch.Tensor) → None

Initialize the wrapper.

Parameters:

float_workspace_buffer (torch.Tensor) – The workspace buffer for the kernels, used to store intermediate results of the CUTLASS segment GEMM kernels. The recommended size is 128MB.

reset_workspace_buffer(float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor) → None

Reset the workspace buffer.

Parameters:
  • float_workspace_buffer (torch.Tensor) – The new float workspace buffer for the kernels.

  • int_workspace_buffer (torch.Tensor) – The new int workspace buffer for the kernels.

run(x: torch.Tensor, weights: torch.Tensor, batch_size: int, weight_column_major: bool, seg_lens: torch.Tensor | None = None, seg_indptr: torch.Tensor | None = None, weight_indices: torch.Tensor | None = None) → torch.Tensor

Run the segment GEMM kernel.

Compute the matrix multiplication between a batch of input tensors (each with a variable number of rows but the same number of columns) and a batch of weight tensors (all with the same number of rows and columns):

\[y[i] = x[i] \times W[i]\]

If weight_indices is provided, the weight tensor for each segment is selected by the corresponding entry of weight_indices:

\[y[i] = x[i] \times W[\text{weight\_indices}[i]]\]

The input x and the output y are represented as ragged tensors: each x[i] is a segment of the concatenated tensor (see the Ragged Tensor tutorial for details). Either a seg_lens or a seg_indptr tensor indicates the start and end of each segment, where seg_indptr is the cumulative sum of seg_lens with an additional 0 at the beginning:

\[\text{seg\_indptr}[i] = \sum_{j=0}^{i-1} \text{seg\_lens}[j], \quad \text{seg\_indptr}[0] = 0\]

  • If seg_lens is provided, then x has shape (sum(seg_lens), d_in) and y has shape (sum(seg_lens), d_out), where d_in is the number of columns of the input tensor and d_out is the number of columns of the output tensor.

  • If seg_indptr is provided, then x has shape (seg_indptr[-1], d_in) and y has shape (seg_indptr[-1], d_out).
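The prefix-sum relationship between seg_lens and seg_indptr can be sketched with the standard library alone. seg_lens_to_indptr is a hypothetical helper for illustration, not part of flashinfer; the lengths [1, 2, 3, 4] are the ones used in the class example, and the resulting row ranges match the slices x[:1], x[1:3], x[3:6], x[6:] checked there.

```python
from itertools import accumulate


def seg_lens_to_indptr(seg_lens):
    """Hypothetical helper: prefix sum with a leading 0, so indptr[i] = sum(seg_lens[:i])."""
    return [0] + list(accumulate(seg_lens))


seg_lens = [1, 2, 3, 4]
seg_indptr = seg_lens_to_indptr(seg_lens)  # [0, 1, 3, 6, 10]

# Segment i occupies rows seg_indptr[i]:seg_indptr[i + 1] of the packed x and y.
segments = [(seg_indptr[i], seg_indptr[i + 1]) for i in range(len(seg_lens))]
# segments == [(0, 1), (1, 3), (3, 6), (6, 10)]

# The last entry is the total number of packed rows, i.e. sum(seg_lens).
total_rows = seg_indptr[-1]  # 10
```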

Parameters:
  • x (torch.Tensor) – The input tensor with shape (sum(seg_lens), d_in).

  • weights (torch.Tensor) – The 3D weight tensor with shape (num_weights, d_in, d_out) if weight_column_major is False, or (num_weights, d_out, d_in) if weight_column_major is True.

  • batch_size (int) – The number of segments.

  • weight_column_major (bool) – Whether the weight tensor is column major.

  • seg_lens (Optional[torch.Tensor]) – The length of each segment, with shape (batch_size,). Expects a 1D tensor of dtype torch.int64.

  • seg_indptr (Optional[torch.Tensor]) – The indptr of the segments, with shape (batch_size + 1,). Expects a 1D tensor of dtype torch.int64. If provided, seg_lens is ignored; otherwise seg_indptr is computed internally from seg_lens.

  • weight_indices (Optional[torch.Tensor]) – The index of the weight tensor to use for each segment, with shape (batch_size,). Expects a 1D tensor of dtype torch.int64. If provided, segment i is multiplied by weights[weight_indices[i]].
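To make the parameter semantics concrete, here is a pure-Python sketch of what run computes. segment_gemm_reference, matmul, and transpose are hypothetical helpers, not flashinfer APIs. The sketch assumes, following the description in this section, that seg_indptr takes precedence over seg_lens, that segment i uses weights[i] unless weight_indices is given, and that with weight_column_major=True each weight matrix is stored as (d_out, d_in). It ignores dtypes, devices, and the workspace buffer.

```python
from itertools import accumulate


def matmul(a, b):
    """Plain-Python product of an (m, k) and a (k, n) matrix given as lists of rows."""
    return [
        [sum(row[t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for row in a
    ]


def transpose(m):
    """Swap rows and columns of a list-of-rows matrix."""
    return [list(col) for col in zip(*m)]


def segment_gemm_reference(x, weights, batch_size, weight_column_major,
                           seg_lens=None, seg_indptr=None, weight_indices=None):
    """Hypothetical reference for the semantics of SegmentGEMMWrapper.run (not a flashinfer API)."""
    if seg_indptr is None:
        if seg_lens is None:
            raise ValueError("one of seg_lens or seg_indptr must be given")
        # seg_indptr[i] = sum(seg_lens[:i]), with a leading 0
        seg_indptr = [0] + list(accumulate(seg_lens))
    y = []
    for i in range(batch_size):
        # Segment i uses W[i], or W[weight_indices[i]] when indices are given.
        w = weights[weight_indices[i] if weight_indices is not None else i]
        # Column-major weights are stored as (d_out, d_in); turn them into (d_in, d_out).
        if weight_column_major:
            w = transpose(w)
        y.extend(matmul(x[seg_indptr[i]:seg_indptr[i + 1]], w))
    return y


# Three packed rows (d_in = 2) split into segments of length 1 and 2.
x = [[1, 0], [0, 1], [1, 1]]
weights = [
    [[1, 2], [3, 4]],  # W[0], row major: (d_in, d_out) = (2, 2)
    [[0, 1], [1, 0]],  # W[1]
]
y = segment_gemm_reference(x, weights, 2, False, seg_lens=[1, 2])
# y == [[1, 2], [1, 0], [1, 1]]
y_idx = segment_gemm_reference(x, weights, 2, False, seg_lens=[1, 2], weight_indices=[1, 0])
# y_idx == [[0, 1], [3, 4], [4, 6]]
```

Small integer inputs keep the expected values exact, so the reference can be checked by hand; passing seg_indptr=[0, 1, 3] instead of seg_lens=[1, 2], or transposed weights with weight_column_major=True, gives the same y.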

Returns:

The output tensor with shape (sum(seg_lens), d_out).

Return type:

torch.Tensor