flashinfer.gemm.tinygemm_bf16

flashinfer.gemm.tinygemm_bf16(input: Tensor, weight: Tensor, out: Tensor, bias: Tensor | None = None, use_pdl: bool = False) → None

SM90+ optimized small GEMM: out = input @ weight.T + bias (equivalent to F.linear).
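
To make the formula concrete, the following is a minimal, dependency-free sketch of out = input @ weight.T + bias on nested Python lists. The function name reference_linear is hypothetical and not part of FlashInfer; it only illustrates the arithmetic and ignores BF16 rounding.

```python
def reference_linear(inp, weight, bias=None):
    """Compute inp @ weight.T + bias on nested Python lists (illustration only)."""
    out = []
    for row in inp:  # one row per batch element
        out_row = []
        for j, w_row in enumerate(weight):  # one weight row per output feature
            # Dot product over the shared input_features dimension.
            acc = sum(x * w for x, w in zip(row, w_row))
            if bias is not None:
                acc += bias[j]
            out_row.append(acc)
        out.append(out_row)
    return out


# batch_size=1, input_features=2, output_features=2
result = reference_linear([[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]], [0.5, -1.0])
print(result)  # [[11.5, 16.0]]
```

Note that weight is indexed by output feature first, which is why the formula uses weight.T rather than weight.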

A latency-optimized, warp-specialized GEMM designed for tiny batch sizes (ideally 1-8 rows, where a single TILE_N=8 tile covers the entire batch dimension). It uses TMA for asynchronous bulk data loads and Ampere-style HMMA tensor-core instructions (mma.sync.aligned.m16n8k16), with BF16 input/weight/bias/output and FP32 internal accumulation. The warp-specialized design (384 threads: 4 compute warps + 8 DMA warps), with 16 pipeline stages and 4x stage unroll, trades peak throughput for minimal latency. Adapted from the TensorRT-LLM tinygemm2 kernel.
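
The figures quoted above can be checked with a little arithmetic. The sketch below assumes the standard CUDA warp size of 32 threads and, as an unconfirmed assumption, that batches larger than TILE_N are covered by additional tiles along the batch dimension. The names are illustrative and do not come from the kernel source.

```python
WARP_SIZE = 32      # threads per warp on NVIDIA GPUs
COMPUTE_WARPS = 4   # from the description above
DMA_WARPS = 8       # from the description above
TILE_N = 8          # batch rows covered by one tile

# 12 warps of 32 threads each give the documented 384 threads.
threads_per_block = (COMPUTE_WARPS + DMA_WARPS) * WARP_SIZE


def batch_tiles(batch_size, tile_n=TILE_N):
    """Tiles needed along the batch dimension (ceiling division).

    Assumption: the batch dimension is tiled in chunks of tile_n rows.
    """
    return -(-batch_size // tile_n)


print(threads_per_block)  # 384
print(batch_tiles(8))     # 1
print(batch_tiles(9))     # 2
```

Under this assumption, any batch of 1 to 8 rows fits in a single tile, which matches the batch range the description calls ideal.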

Parameters:
  • input (torch.Tensor) – Input activations of shape (batch_size, input_features). Must be bfloat16, contiguous. input_features must be a multiple of 64.

  • weight (torch.Tensor) – Weight matrix of shape (output_features, input_features). Must be bfloat16, contiguous (row-major). output_features must be a multiple of 16.

  • out (torch.Tensor) – Pre-allocated output tensor of shape (batch_size, output_features). Must be bfloat16, contiguous. Mutated in place.

  • bias (Optional[torch.Tensor]) – Optional bias vector of shape (output_features,). Must be bfloat16, contiguous. If None, zero bias is used.

  • use_pdl (bool) – Enable Programmatic Dependent Launch (stream serialization). When True, the kernel calls cudaGridDependencySynchronize() so that its DMA can overlap with the preceding kernel’s compute. Only enable when ALL preceding operations on the stream also use PDL; otherwise the kernel hangs. Defaults to False.
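
The shape and alignment constraints listed above can be checked before the call. The helper below is a hypothetical sketch (check_tinygemm_shapes is not a FlashInfer function) that works on plain shape tuples; it does not check dtype or contiguity, and the library is documented to perform its own validation.

```python
def check_tinygemm_shapes(input_shape, weight_shape, out_shape, bias_shape=None):
    """Raise ValueError if shapes violate the documented constraints."""
    if len(input_shape) != 2 or len(weight_shape) != 2 or len(out_shape) != 2:
        raise ValueError("input, weight and out must all be 2-D")

    batch_size, input_features = input_shape
    output_features, weight_in = weight_shape

    if weight_in != input_features:
        raise ValueError("weight must have shape (output_features, input_features)")
    if input_features % 64 != 0:
        raise ValueError("input_features must be a multiple of 64")
    if output_features % 16 != 0:
        raise ValueError("output_features must be a multiple of 16")
    if tuple(out_shape) != (batch_size, output_features):
        raise ValueError("out must have shape (batch_size, output_features)")
    if bias_shape is not None and tuple(bias_shape) != (output_features,):
        raise ValueError("bias must have shape (output_features,)")


# A valid configuration: 4 rows, 128 input features, 64 output features.
check_tinygemm_shapes((4, 128), (64, 128), (4, 64), (64,))
```

With real tensors, the same checks could be applied by passing tuple(t.shape) for each argument.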

Notes

Requires SM90+ (Hopper or newer). Raises ValueError if tensor dimensions, dtypes, or alignment constraints are violated.
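
A guarded usage sketch follows, assuming PyTorch and a FlashInfer build that exposes tinygemm_bf16 are installed. The wrapper run_tinygemm_example is hypothetical; it returns None instead of calling the kernel when the packages, the function, or an SM90+ device are unavailable, and otherwise compares the result against torch.nn.functional.linear.

```python
def run_tinygemm_example(batch_size=4, input_features=128, output_features=64):
    """Call tinygemm_bf16 if the environment supports it, else return None."""
    try:
        import torch
        import flashinfer.gemm
    except Exception:  # missing package or a failed CUDA extension import
        return None

    kernel = getattr(flashinfer.gemm, "tinygemm_bf16", None)
    if kernel is None or not torch.cuda.is_available():
        return None
    major, _minor = torch.cuda.get_device_capability()
    if major < 9:  # the kernel requires SM90+ (Hopper or newer)
        return None

    opts = {"dtype": torch.bfloat16, "device": "cuda"}
    x = torch.randn(batch_size, input_features, **opts)       # 128 is a multiple of 64
    w = torch.randn(output_features, input_features, **opts)  # 64 is a multiple of 16
    b = torch.randn(output_features, **opts)
    out = torch.empty(batch_size, output_features, **opts)    # pre-allocated output

    kernel(x, w, out, bias=b)  # writes the result into `out` in place

    ref = torch.nn.functional.linear(x, w, b)
    max_err = (out.float() - ref.float()).abs().max().item()
    return tuple(out.shape), max_err


result = run_tinygemm_example()
```

use_pdl is left at its default of False here, since enabling it is only safe when every preceding operation on the stream also uses PDL.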