flashinfer.gemm.tinygemm_bf16
- flashinfer.gemm.tinygemm_bf16(input: Tensor, weight: Tensor, out: Tensor, bias: Tensor | None = None, use_pdl: bool = False) → None
SM90+ optimized small GEMM: out = input @ weight.T + bias (equivalent to F.linear).

A latency-optimized, warp-specialized GEMM designed for tiny batch sizes (ideally 1-8 rows, where a single TILE_N=8 tile covers the entire batch dimension) using Ampere-style HMMA instructions. It uses TMA for asynchronous bulk data loads and mma.sync.aligned.m16n8k16 tensor-core instructions, with BF16 input/weight/bias/output and FP32 internal accumulation. The warp-specialized design (384 threads: 4 compute + 8 DMA warps) with 16 pipeline stages and a 4x stage unroll gives up some peak throughput in favor of minimal latency. Adapted from the TensorRT-LLM tinygemm2 kernel.

Parameters:
- input (torch.Tensor) – Input activations of shape (batch_size, input_features). Must be bfloat16, contiguous. input_features must be a multiple of 64.
- weight (torch.Tensor) – Weight matrix of shape (output_features, input_features). Must be bfloat16, contiguous (row-major). output_features must be a multiple of 16.
- out (torch.Tensor) – Pre-allocated output tensor of shape (batch_size, output_features). Must be bfloat16, contiguous. Mutated in place.
- bias (Optional[torch.Tensor]) – Optional bias vector of shape (output_features,). Must be bfloat16, contiguous. If None, zero bias is used.
- use_pdl (bool) – Enable Programmatic Dependent Launch (stream serialization). When True, the kernel uses cudaGridDependencySynchronize() to overlap its DMA with the preceding kernel’s compute. Only enable it when ALL preceding stream operations also use PDL; otherwise the kernel hangs. Defaults to False.
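The shape rules above can be checked before touching a GPU. The sketch below is a CPU-only illustration and is not part of FlashInfer: check_tinygemm_shapes and reference_linear are hypothetical helpers that mirror the documented shape constraints and the out = input @ weight.T + bias formula on plain Python lists. Dtype (bfloat16) and contiguity checks are left out because plain tuples carry no such information, and Python floats do not reproduce the kernel's BF16 rounding.

```python
from typing import List, Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def check_tinygemm_shapes(
    input_shape: Shape,
    weight_shape: Shape,
    out_shape: Shape,
    bias_shape: Optional[Shape] = None,
) -> None:
    """Hypothetical helper: raise ValueError if shapes break the documented rules.

    Only shapes are checked; bfloat16 dtype and contiguity are not modeled here.
    """
    if len(input_shape) != 2 or len(weight_shape) != 2 or len(out_shape) != 2:
        raise ValueError("input, weight and out must all be 2-D")
    batch_size, input_features = input_shape
    output_features, weight_in = weight_shape
    if weight_in != input_features:
        raise ValueError("weight must have shape (output_features, input_features)")
    if input_features % 64 != 0:
        raise ValueError("input_features must be a multiple of 64")
    if output_features % 16 != 0:
        raise ValueError("output_features must be a multiple of 16")
    if tuple(out_shape) != (batch_size, output_features):
        raise ValueError("out must have shape (batch_size, output_features)")
    if bias_shape is not None and tuple(bias_shape) != (output_features,):
        raise ValueError("bias must have shape (output_features,)")


def reference_linear(
    x: Sequence[Sequence[float]],
    w: Sequence[Sequence[float]],
    b: Optional[Sequence[float]] = None,
) -> List[List[float]]:
    """Hypothetical reference: out[i][j] = sum_k x[i][k] * w[j][k] + b[j].

    weight is stored as (output_features, input_features), so each output
    element is a dot product of an input row with a weight row (x @ w.T).
    """
    out = []
    for row in x:
        out_row = []
        for j, w_row in enumerate(w):
            acc = sum(xv * wv for xv, wv in zip(row, w_row))
            # A missing bias behaves like a zero bias, as documented.
            out_row.append(acc + (b[j] if b is not None else 0.0))
        out.append(out_row)
    return out
```

For example, a (4, 128) input with a (32, 128) weight passes, while an input_features of 100 is rejected because it is not a multiple of 64.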
Notes
Requires SM90+ (Hopper or newer). Raises ValueError if tensor dimensions, dtypes, or alignment constraints are violated.
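A minimal end-to-end usage sketch, assuming FlashInfer and a CUDA build of PyTorch are installed on an SM90+ GPU. It allocates BF16 tensors that satisfy the documented constraints, calls tinygemm_bf16 with the signature shown on this page, and compares the result with F.linear. The names supports_tinygemm and run_tinygemm_example are hypothetical, and the comparison tolerances are a loose guess for BF16 output rather than a documented accuracy bound.

```python
from typing import Tuple


def supports_tinygemm(capability: Tuple[int, int]) -> bool:
    """Hypothetical check: True for compute capability 9.0+ (Hopper or newer)."""
    return tuple(capability) >= (9, 0)


def run_tinygemm_example(batch_size: int = 4,
                         input_features: int = 4096,   # multiple of 64
                         output_features: int = 1024,  # multiple of 16
                         ) -> None:
    # Imports are deferred so this module can be loaded on machines without a GPU.
    import torch
    import torch.nn.functional as F
    from flashinfer.gemm import tinygemm_bf16

    if not torch.cuda.is_available() or not supports_tinygemm(
        torch.cuda.get_device_capability()
    ):
        raise RuntimeError("tinygemm_bf16 requires an SM90+ GPU")

    dev, dt = "cuda", torch.bfloat16
    x = torch.randn(batch_size, input_features, device=dev, dtype=dt)
    # Weight is (output_features, input_features), row-major, like nn.Linear.weight.
    w = torch.randn(output_features, input_features, device=dev, dtype=dt) * 0.02
    b = torch.randn(output_features, device=dev, dtype=dt)
    # The caller pre-allocates the output; the kernel writes into it in place.
    out = torch.empty(batch_size, output_features, device=dev, dtype=dt)

    # use_pdl is left at its default (False).
    tinygemm_bf16(x, w, out, bias=b)

    # Loose, assumed tolerances: both sides round their result to BF16.
    torch.testing.assert_close(out, F.linear(x, w, b), rtol=2e-2, atol=2e-2)
```

Call run_tinygemm_example() on a Hopper-class or newer machine; elsewhere it raises RuntimeError instead of launching the kernel. Keeping use_pdl=False is the safe choice unless every preceding operation on the stream also uses PDL.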