flashinfer.gemm.tgv_gemm_sm100

flashinfer.gemm.tgv_gemm_sm100(a: Tensor, b: Tensor, bias: Tensor, pdl: bool = False, out: Tensor | None = None) → Tensor

Perform TGV GEMM on SM100 architecture with automatic dtype detection.

Computes out = a @ b + bias. Both a and b must share the same floating-point dtype (torch.bfloat16 or torch.float16).
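As a rough illustration of the arithmetic only, the pure-Python sketch below computes a @ b + bias on nested lists, adding bias[j] to column j of every output row. The helper name reference_gemm_bias is hypothetical and not part of FlashInfer; it ignores dtype, memory layout, and GPU execution.

```python
def reference_gemm_bias(a, b, bias):
    """Naive reference for out = a @ b + bias on nested lists.

    a is M x K, b is K x N, bias has length N. Hypothetical helper for
    illustration only; it is not part of FlashInfer.
    """
    m, k = len(a), len(a[0])
    n = len(b[0])
    if len(b) != k:
        raise ValueError("inner dimensions of a and b must match")
    if len(bias) != n:
        raise ValueError("bias must have one entry per output column")
    # out[i][j] = sum over p of a[i][p] * b[p][j], plus bias[j]
    return [
        [sum(a[i][p] * b[p][j] for p in range(k)) + bias[j] for j in range(n)]
        for i in range(m)
    ]


out = reference_gemm_bias([[1, 2], [3, 4]], [[5, 6], [7, 8]], [10, 20])
print(out)  # [[29, 42], [53, 70]]
```

The same bias vector is added to every row, which matches the (N,) bias shape described in the parameter list.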

Parameters:
  • a (torch.Tensor) – First input tensor of shape (M, K) in row-major layout.

  • b (torch.Tensor) – Second input tensor of shape (K, N) in column-major layout, i.e. the transpose of a row-major (N, K) tensor rather than PyTorch's default row-major layout.

  • bias (torch.Tensor) – Bias tensor of shape (N,) to add to each row of a @ b.

  • pdl (bool) – Whether to use PDL (Programmatic Dependent Launch). Defaults to False.

  • out (Optional[torch.Tensor]) – Pre-allocated output tensor of shape (M, N). If None, a new tensor is allocated.
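The layouts described above can be sketched in terms of element strides. The helper dense_strides and the wrapper run_example are hypothetical names introduced here for illustration. The sketch assumes that transposing a contiguous (N, K) tensor is an acceptable way to produce the column-major (K, N) operand; this page does not state whether the function validates strides. run_example is only defined, not called, because it needs PyTorch, FlashInfer, and a supported GPU.

```python
def dense_strides(rows, cols, order):
    """Element strides of a dense (rows, cols) matrix.

    order="row" gives row-major strides, order="col" gives column-major.
    Hypothetical helper, not part of FlashInfer or PyTorch.
    """
    if order == "row":
        return (cols, 1)
    if order == "col":
        return (1, rows)
    raise ValueError("order must be 'row' or 'col'")


def run_example(m=128, n=256, k=512):
    """Usage sketch; requires PyTorch, FlashInfer, and an SM100/SM103 GPU."""
    import torch
    from flashinfer.gemm import tgv_gemm_sm100

    dtype = torch.bfloat16
    # a: (M, K), row-major (PyTorch's default contiguous layout).
    a = torch.randn(m, k, dtype=dtype, device="cuda")
    # Assumption: the transpose of a contiguous (N, K) tensor is the
    # column-major (K, N) operand described in the parameter list.
    b = torch.randn(n, k, dtype=dtype, device="cuda").t()
    # bias: (N,), same dtype as a and b in this sketch.
    bias = torch.randn(n, dtype=dtype, device="cuda")

    # Sanity-check the layout before calling the kernel.
    assert tuple(b.shape) == (k, n)
    assert tuple(b.stride()) == dense_strides(k, n, "col")

    return tgv_gemm_sm100(a, b, bias)  # pdl and out left at their defaults


print(dense_strides(512, 256, "row"))  # (256, 1)
print(dense_strides(512, 256, "col"))  # (1, 512)
```

In a column-major (K, N) matrix, moving down a column steps through adjacent memory, so the stride along the K dimension is 1 and the stride along N is K.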

Returns:

Output tensor of shape (M, N) in row-major layout.

Return type:

torch.Tensor

Notes

Requires a GPU with the SM100 or SM103 architecture. Supported dtypes are torch.bfloat16 and torch.float16.
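A caller may want to check the device before dispatching to this function. The sketch below assumes that SM100 and SM103 correspond to CUDA compute capabilities 10.0 and 10.3; the names SUPPORTED_CAPABILITIES, is_supported_capability, and current_gpu_is_supported are hypothetical and not part of FlashInfer.

```python
# Assumption: SM100 -> compute capability (10, 0), SM103 -> (10, 3).
SUPPORTED_CAPABILITIES = {(10, 0), (10, 3)}


def is_supported_capability(major, minor):
    """Return True if (major, minor) is one of the assumed supported targets."""
    return (major, minor) in SUPPORTED_CAPABILITIES


def current_gpu_is_supported():
    """Query the active CUDA device through PyTorch (imported lazily)."""
    import torch

    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return is_supported_capability(major, minor)


print(is_supported_capability(10, 0))  # True
print(is_supported_capability(9, 0))   # False
```

Keeping the capability test separate from the PyTorch query lets the fallback logic be exercised on machines without a GPU.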