flashinfer.gemm.tgv_gemm_sm100
- flashinfer.gemm.tgv_gemm_sm100(a: Tensor, b: Tensor, bias: Tensor, pdl: bool = False, out: Tensor | None = None) → Tensor
Perform TGV GEMM on SM100 architecture with automatic dtype detection.
Computes out = a @ b + bias. Both a and b must share the same floating-point dtype (torch.bfloat16 or torch.float16).
- Parameters:
a (torch.Tensor) – First input tensor of shape (M, K) in row-major layout.
b (torch.Tensor) – Second input tensor of shape (K, N) in column-major layout (transposed from the typical PyTorch row-major convention).
bias (torch.Tensor) – Bias tensor of shape (N,) to add to each row of a @ b.
pdl (bool) – Whether to use PDL (Programmatic Dependent Launch). Defaults to False.
out (Optional[torch.Tensor]) – Pre-allocated output tensor of shape (M, N). If None, a new tensor is allocated.
- Returns:
Output tensor of shape (M, N) in row-major layout.
- Return type:
torch.Tensor
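Example

The shapes and layouts above can be illustrated with a small sketch. The helper names make_tgv_inputs and reference_gemm are hypothetical and not part of FlashInfer. The sketch runs on CPU and only builds inputs and a plain PyTorch reference result. The documented call appears in a comment because it needs an SM100 or SM103 GPU and, presumably, tensors that live on that CUDA device.

```python
import torch


def make_tgv_inputs(m, k, n, dtype=torch.bfloat16, device="cpu"):
    """Build a, b, bias with the shapes and layouts described above (hypothetical helper)."""
    # a: (M, K), row-major (the default contiguous PyTorch layout).
    a = torch.randn(m, k, dtype=dtype, device=device)
    # b: allocate (N, K) contiguously, then view it as (K, N). The view has
    # strides (1, K), which is column-major storage for a (K, N) matrix.
    b = torch.randn(n, k, dtype=dtype, device=device).t()
    # bias: (N,), broadcast across the M rows of a @ b.
    bias = torch.randn(n, dtype=dtype, device=device)
    return a, b, bias


def reference_gemm(a, b, bias):
    """Plain PyTorch equivalent of out = a @ b + bias (hypothetical helper)."""
    return a @ b + bias


a, b, bias = make_tgv_inputs(m=8, k=64, n=32)
expected = reference_gemm(a, b, bias)
print(a.is_contiguous(), b.stride(), tuple(expected.shape))

# On an SM100/SM103 GPU with flashinfer installed, and with the inputs created
# on the CUDA device, the documented call would be:
#   out = flashinfer.gemm.tgv_gemm_sm100(a, b, bias)
```

Creating b as the transpose of a contiguous (N, K) tensor avoids a copy. If a (K, N) weight already exists in row-major form, b.t().contiguous().t() yields the same column-major strides at the cost of one copy.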
Notes
Requires SM100 or SM103 architecture. Supported dtypes are torch.bfloat16 and torch.float16.
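A caller may want to guard against unsupported hardware before dispatching to this kernel. The following is a minimal sketch, assuming SM100 and SM103 correspond to CUDA compute capabilities 10.0 and 10.3. The helper names are hypothetical, and FlashInfer may perform its own check internally.

```python
# Compute capabilities assumed to correspond to SM100 and SM103.
SUPPORTED_CAPABILITIES = {(10, 0), (10, 3)}


def is_tgv_gemm_capable(capability):
    """Return True if a (major, minor) compute capability is in the supported set."""
    return tuple(capability) in SUPPORTED_CAPABILITIES


def current_device_is_capable():
    """Check the active CUDA device via PyTorch; returns False when CUDA is unavailable."""
    import torch  # imported lazily so the pure-Python check above has no dependencies

    if not torch.cuda.is_available():
        return False
    # get_device_capability() returns a (major, minor) tuple for the current device.
    return is_tgv_gemm_capable(torch.cuda.get_device_capability())
```

When current_device_is_capable() returns False, a fallback such as torch.addmm(bias, a, b) computes the same a @ b + bias result without the specialized kernel.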